// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstddef>
#include <cstdint>
#include <limits>
#include <map>
#include <memory>
#include <optional>
#include <ostream>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include <boost/align/aligned_allocator.hpp>
#include <boost/dynamic_bitset.hpp>
#include <NamedType/named_type.hpp>

#include "common/FieldMeta.h"
#include "common/ArrayOffsets.h"
#include "common/OffsetMapping.h"
#include "query/Utils.h"
#include "pb/schema.pb.h"
#include "knowhere/index/index_node.h"

namespace milvus {

// Scan cost accumulated by a single search/query: the bytes counted in
// scanned_remote_bytes and the overall bytes counted in scanned_total_bytes.
//
// Value type with no owned resources, so all copy/move operations are the
// compiler-generated ones (Rule of Zero); assignment returns StorageCost& and
// therefore chains like any other value type.
struct StorageCost {
    int64_t scanned_remote_bytes = 0;
    int64_t scanned_total_bytes = 0;

    StorageCost() = default;

    StorageCost(int64_t scanned_remote_bytes, int64_t scanned_total_bytes)
        : scanned_remote_bytes(scanned_remote_bytes),
          scanned_total_bytes(scanned_total_bytes) {
    }

    // Field-wise sum of two costs.
    StorageCost
    operator+(const StorageCost& rhs) const {
        return {scanned_remote_bytes + rhs.scanned_remote_bytes,
                scanned_total_bytes + rhs.scanned_total_bytes};
    }

    // Adds rhs into this cost field-wise and returns *this.
    StorageCost&
    operator+=(const StorageCost& rhs) {
        scanned_remote_bytes += rhs.scanned_remote_bytes;
        scanned_total_bytes += rhs.scanned_total_bytes;
        return *this;
    }

    // Scales both counters by factor; the double result is truncated toward
    // zero by the conversion back to int64_t.
    StorageCost
    operator*(const double factor) const {
        return {static_cast<int64_t>(scanned_remote_bytes * factor),
                static_cast<int64_t>(scanned_total_bytes * factor)};
    }

    // In-place version of operator*; returns *this.
    StorageCost&
    operator*=(const double factor) {
        *this = *this * factor;
        return *this;
    }

    // Human-readable form, e.g.
    // "scanned_remote_bytes: 1, scanned_total_bytes: 2".
    // Built with std::to_string so this header needs no formatting library.
    std::string
    ToString() const {
        return "scanned_remote_bytes: " + std::to_string(scanned_remote_bytes) +
               ", scanned_total_bytes: " + std::to_string(scanned_total_bytes);
    }
};

// Streams the StorageCost::ToString() text of `cost` into `os` and returns
// the same stream so insertions can be chained.
inline std::ostream&
operator<<(std::ostream& os, const StorageCost& cost) {
    return os << cost.ToString();
}

// A pending candidate in ChunkMergeIterator's heap: an (offset, distance)
// pair plus the index of the iterator that produced it, so the merge can
// refill from that same iterator once this candidate is consumed.
struct OffsetDisPair {
 public:
    OffsetDisPair(std::pair<int64_t, float> offset_distance, int source_idx)
        : off_dis_{offset_distance}, iterator_idx_{source_idx} {
    }

    // The stored (offset, distance) pair.
    const std::pair<int64_t, float>&
    GetOffDis() const {
        return off_dis_;
    }

    // Index of the iterator this pair came from.
    int
    GetIteratorIdx() const {
        return iterator_idx_;
    }

 private:
    std::pair<int64_t, float> off_dis_;
    int iterator_idx_;
};

// Ordering for ChunkMergeIterator's std::priority_queue. It returns true when
// `left` has lower priority than `right`, so the closest candidate is on top:
//  - larger_is_closer_ == true (IP/Cosine): the larger distance is on top;
//  - larger_is_closer_ == false (L2): the smaller distance is on top.
// Equal distances are broken by offset, leaving the larger offset on top.
struct OffsetDisPairComparator {
    bool larger_is_closer_ = false;

    OffsetDisPairComparator(bool larger_is_closer = false)
        : larger_is_closer_(larger_is_closer) {
    }

    bool
    operator()(const std::shared_ptr<OffsetDisPair>& left,
               const std::shared_ptr<OffsetDisPair>& right) const {
        const auto& lhs = left->GetOffDis();
        const auto& rhs = right->GetOffDis();
        if (lhs.second == rhs.second) {
            // Tie on distance: fall back to the offset.
            return lhs.first < rhs.first;
        }
        return larger_is_closer_ ? lhs.second < rhs.second
                                 : lhs.second > rhs.second;
    }
};

// Abstract iterator yielding (offset, distance) pairs one at a time.
// SearchResult holds these in vector_iterators_ / element_iterators_;
// ChunkMergeIterator is the implementation defined in this header.
class VectorIterator {
 public:
    virtual ~VectorIterator() = default;

    // Whether another result is available from Next().
    virtual bool
    HasNext() = 0;

    // Returns the next (offset, distance) pair. ChunkMergeIterator returns
    // std::nullopt once exhausted; other implementations presumably follow
    // the same contract — confirm.
    virtual std::optional<std::pair<int64_t, float>>
    Next() = 0;
};

// Multi-way merge iterator for vector search results from multiple chunks
//
// Merges knowhere iterators from different chunks using a min-heap,
// returning results in distance-sorted order.
class ChunkMergeIterator : public VectorIterator {
 public:
    ChunkMergeIterator(int chunk_count,
                       const milvus::OffsetMapping& offset_mapping,
                       const std::vector<int64_t>& total_rows_until_chunk = {},
                       bool larger_is_closer = false)
        : offset_mapping_(&offset_mapping),
          larger_is_closer_(larger_is_closer),
          heap_(OffsetDisPairComparator(larger_is_closer)) {
        iterators_.reserve(chunk_count);
    }

    bool
    HasNext() override {
        return !heap_.empty();
    }

    std::optional<std::pair<int64_t, float>>
    Next() override {
        if (!heap_.empty()) {
            auto top = heap_.top();
            heap_.pop();
            if (iterators_[top->GetIteratorIdx()]->HasNext()) {
                auto origin_pair = iterators_[top->GetIteratorIdx()]->Next();
                auto off_dis_pair = std::make_shared<OffsetDisPair>(
                    origin_pair, top->GetIteratorIdx());
                heap_.push(off_dis_pair);
            }
            auto result = top->GetOffDis();
            if (offset_mapping_ != nullptr) {
                result.first = offset_mapping_->GetLogicalOffset(result.first);
            }
            return result;
        }
        return std::nullopt;
    }

    bool
    AddIterator(knowhere::IndexNode::IteratorPtr iter) {
        if (!sealed && iter != nullptr) {
            iterators_.emplace_back(iter);
            return true;
        }
        return false;
    }

    void
    seal() {
        sealed = true;
        int idx = 0;
        for (auto& iter : iterators_) {
            if (iter->HasNext()) {
                auto origin_pair = iter->Next();
                auto off_dis_pair =
                    std::make_shared<OffsetDisPair>(origin_pair, idx++);
                heap_.push(off_dis_pair);
            }
        }
    }

 private:
    int64_t
    convert_to_segment_offset(int64_t chunk_offset, int chunk_idx) {
        if (total_rows_until_chunk_.size() == 0) {
            AssertInfo(
                iterators_.size() == 1,
                "Wrong state for vectorIterators, which having incorrect "
                "kw_iterator count:{} "
                "without setting value for chunk_rows, "
                "cannot convert chunk_offset to segment_offset correctly",
                iterators_.size());
            return chunk_offset;
        }
        return total_rows_until_chunk_[chunk_idx] + chunk_offset;
    }

 private:
    std::vector<knowhere::IndexNode::IteratorPtr> iterators_;
    std::priority_queue<std::shared_ptr<OffsetDisPair>,
                        std::vector<std::shared_ptr<OffsetDisPair>>,
                        OffsetDisPairComparator>
        heap_;
    bool sealed = false;
    const milvus::OffsetMapping* offset_mapping_ = nullptr;
    std::vector<int64_t> total_rows_until_chunk_;
    bool larger_is_closer_ = false;
    //currently, ChunkMergeIterator is guaranteed to be used serially without concurrent problem, in the future
    //we may need to add mutex to protect the variable sealed
};

// Per-segment search output: raw hits filled during search, data produced by
// the fill/reduce steps, and optional per-query VectorIterators.
struct SearchResult {
    SearchResult() = default;

    // Total number of hits across all queries: the entry at index total_nq_
    // of topk_per_nq_prefix_sum_ (0 if the prefix sum has not been built).
    int64_t
    get_total_result_count() const {
        if (topk_per_nq_prefix_sum_.empty()) {
            return 0;
        }
        AssertInfo(topk_per_nq_prefix_sum_.size() == total_nq_ + 1,
                   "wrong topk_per_nq_prefix_sum_ size {}",
                   topk_per_nq_prefix_sum_.size());
        return topk_per_nq_prefix_sum_[total_nq_];
    }

 public:
    // Builds one sealed ChunkMergeIterator per query and stores them in
    // vector_iterators_. kw_iterators must hold nq * chunk_count entries.
    // Entry i is routed to query i % nq, and the k-th entry a query receives
    // is treated as that query's iterator for chunk k.
    void
    AssembleChunkVectorIterators(
        int64_t nq,
        int chunk_count,
        const std::vector<int64_t>& total_rows_until_chunk,
        const std::vector<knowhere::IndexNode::IteratorPtr>& kw_iterators,
        const milvus::OffsetMapping& offset_mapping,
        bool larger_is_closer = false) {
        AssertInfo(kw_iterators.size() == nq * chunk_count,
                   "kw_iterators count:{} is not equal to nq*chunk_count:{}, "
                   "wrong state",
                   kw_iterators.size(),
                   nq * chunk_count);
        // Create the merge iterators up front so there is exactly one per
        // query, even when chunk_count is 0 and no kw iterator exists.
        std::vector<std::shared_ptr<ChunkMergeIterator>> merge_iters;
        merge_iters.reserve(nq);
        for (int64_t q = 0; q < nq; ++q) {
            merge_iters.emplace_back(
                std::make_shared<ChunkMergeIterator>(chunk_count,
                                                     offset_mapping,
                                                     total_rows_until_chunk,
                                                     larger_is_closer));
        }
        // Non-empty kw_iterators implies nq > 0, so the modulo is safe.
        for (size_t i = 0; i < kw_iterators.size(); ++i) {
            merge_iters[i % merge_iters.size()]->AddIterator(kw_iterators[i]);
        }
        for (const auto& merge_iter : merge_iters) {
            merge_iter->seal();
        }
        vector_iterators_.emplace(merge_iters.begin(), merge_iters.end());
    }

 public:
    // Scalars are zero/null-initialized so a default-constructed result is
    // never read uninitialized (e.g. total_nq_ in get_total_result_count).
    int64_t total_nq_{0};
    int64_t unity_topK_{0};
    int64_t total_data_cnt_{0};
    void* segment_{nullptr};

    // first fill data during search, and then update data after reducing search results
    std::vector<float> distances_;
    std::vector<int64_t> seg_offsets_;
    std::optional<std::vector<GroupByValueType>> group_by_values_;
    std::optional<int64_t> group_size_;

    // first fill data during fillPrimaryKey, and then update data after reducing search results
    std::vector<PkType> primary_keys_;
    // Value-initialized; presumably DataType::NONE (0) — confirm.
    DataType pk_type_{};

    // fill data during reducing search result
    std::vector<int64_t> result_offsets_;
    // after reducing search result done,
    // size(distances_) = size(seg_offsets_) = size(primary_keys_)

    // set output fields data when fill target entity
    std::map<FieldId, std::unique_ptr<milvus::DataArray>> output_fields_data_;

    // used for reduce, filter invalid pk, get real topks count
    std::vector<size_t> topk_per_nq_prefix_sum_{};

    //Vector iterators, used for group by
    std::optional<std::vector<std::shared_ptr<VectorIterator>>>
        vector_iterators_;
    // record the storage usage in search
    StorageCost search_storage_cost_;

    bool element_level_{false};
    std::vector<int32_t> element_indices_;
    std::optional<std::vector<std::shared_ptr<VectorIterator>>>
        element_iterators_;
    std::shared_ptr<const IArrayOffsets> array_offsets_{nullptr};
    std::vector<std::unique_ptr<uint8_t[]>> chunk_buffers_{};

    // True when the iterator set matching element_level_ has been populated.
    bool
    HasIterators() const {
        return (element_level_ && element_iterators_.has_value()) ||
               (!element_level_ && vector_iterators_.has_value());
    }

    // Returns (by value) element_iterators_ when element_level_ is set,
    // otherwise vector_iterators_.
    std::optional<std::vector<std::shared_ptr<VectorIterator>>>
    GetIterators() {
        if (element_level_) {
            return element_iterators_;
        } else {
            return vector_iterators_;
        }
    }
};

using SearchResultPtr = std::shared_ptr<SearchResult>;
using SearchResultOpt = std::optional<SearchResult>;

// Per-segment retrieve (query) output.
struct RetrieveResult {
    RetrieveResult() = default;

 public:
    // Scalars are zero/null-initialized so a default-constructed result
    // holds no indeterminate values.
    int64_t total_data_cnt_{0};
    void* segment_{nullptr};
    std::vector<int64_t> result_offsets_;
    std::vector<DataArray> field_data_;
    bool has_more_result = true;
    // record the storage usage in retrieve
    StorageCost retrieve_storage_cost_;
};

using RetrieveResultPtr = std::shared_ptr<RetrieveResult>;
using RetrieveResultOpt = std::optional<RetrieveResult>;
}  // namespace milvus
