// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "index/VectorDiskIndex.h"

#include "common/Tracer.h"
#include "common/Types.h"
#include "common/Utils.h"
#include "config/ConfigKnowhere.h"
#include "index/Meta.h"
#include "index/Utils.h"
#include "storage/LocalChunkManagerSingleton.h"
#include "storage/Util.h"
#include "common/Consts.h"
#include "common/RangeSearchHelper.h"
#include "indexbuilder/types.h"
#include "filemanager/FileManager.h"
#include "log/Log.h"

namespace milvus::index {

#define kSearchListMaxValue1 200    // used if topk <= 20
#define kSearchListMaxValue2 65535  // used for topk > 20
#define kPrepareDim 100
#define kPrepareRows 1

// Constructs a disk-backed vector index.
//
// Creates the disk file manager from `file_manager_context`, resets the local
// index directory, and asks the knowhere index factory for an index object of
// the configured index type and `version`.
//
// Throws Unsupported when the factory reports invalid_index_error, and
// KnowhereError for any other factory failure.
template <typename T>
VectorDiskAnnIndex<T>::VectorDiskAnnIndex(
    DataType elem_type,
    const IndexType& index_type,
    const MetricType& metric_type,
    const IndexVersion& version,
    const storage::FileManagerContext& file_manager_context)
    : VectorIndex(index_type, metric_type), elem_type_(elem_type) {
    CheckMetricTypeSupport<T>(metric_type);

    file_manager_ =
        std::make_shared<storage::DiskFileManagerImpl>(file_manager_context);
    AssertInfo(file_manager_ != nullptr, "create file manager failed!");

    auto chunk_manager =
        storage::LocalChunkManagerSingleton::GetInstance().GetChunkManager();
    auto index_dir = file_manager_->GetLocalIndexObjectPrefix();

    // QueryNode already guards against duplicate loads, so a directory that
    // is still present here means Milvus rebooted inside the same pod. Drop
    // the stale files so the segment can be loaded again from scratch.
    if (chunk_manager->Exist(index_dir)) {
        chunk_manager->RemoveDir(index_dir);
    }
    CheckCompatible(version);
    chunk_manager->CreateDir(index_dir);

    auto file_manager_pack =
        knowhere::Pack(std::shared_ptr<milvus::FileManager>(file_manager_));
    auto created = knowhere::IndexFactory::Instance().Create<T>(
        GetIndexType(), version, file_manager_pack);

    // Failure path first: map the knowhere status onto a Milvus error code.
    if (!created.has_value()) {
        if (created.error() == knowhere::Status::invalid_index_error) {
            ThrowInfo(ErrorCode::Unsupported, created.what());
        }
        ThrowInfo(ErrorCode::KnowhereError, created.what());
    }
    index_ = created.value();
}

// BinarySet-based overload: the binary set is ignored and the call is
// forwarded to the trace-aware overload with a default-constructed context.
template <typename T>
void
VectorDiskAnnIndex<T>::Load(const BinarySet& binary_set /* not used */,
                            const Config& config) {
    milvus::tracer::TraceContext empty_ctx{};
    Load(empty_ctx, config);
}

// Loads the index from local disk, tracing the two phases separately.
//
// Phase 1 (span "SegCoreReadDiskIndexFile"): unless the index reports that it
// loads via stream, the files listed under "index_files" in `config` are
// cached to local disk with the configured load priority (HIGH when unset).
// Phase 2 (span "SegCoreEngineLoadDiskIndex"): the index is deserialized with
// the load config produced by update_load_json().
// Afterwards, if a valid-data file exists under the local index prefix, its
// bitmap is unpacked and passed to BuildValidData(); finally the dimension
// reported by the index is recorded via SetDim().
//
// Asserts that "index_files" is present; throws UnexpectedError when
// deserialization does not return success.
template <typename T>
void
VectorDiskAnnIndex<T>::Load(milvus::tracer::TraceContext ctx,
                            const Config& config) {
    knowhere::Json load_config = update_load_json(config);

    // start read file span with active scope
    {
        auto read_file_span =
            milvus::tracer::StartSpan("SegCoreReadDiskIndexFile", &ctx);
        auto read_scope =
            milvus::tracer::GetTracer()->WithActiveSpan(read_file_span);
        auto index_files =
            GetValueFromConfig<std::vector<std::string>>(config, "index_files");
        AssertInfo(index_files.has_value(),
                   "index file paths is empty when load disk ann index data");
        // If index is loaded with stream, we don't need to cache index to disk
        if (!index_.LoadIndexWithStream()) {
            auto load_priority =
                GetValueFromConfig<milvus::proto::common::LoadPriority>(
                    config, milvus::LOAD_PRIORITY)
                    .value_or(milvus::proto::common::LoadPriority::HIGH);
            file_manager_->CacheIndexToDisk(index_files.value(), load_priority);
        }
        read_file_span->End();
    }

    // start engine load index span
    auto span_load_engine =
        milvus::tracer::StartSpan("SegCoreEngineLoadDiskIndex", &ctx);
    auto engine_scope =
        milvus::tracer::GetTracer()->WithActiveSpan(span_load_engine);
    auto stat = index_.Deserialize(knowhere::BinarySet(), load_config);
    if (stat != knowhere::Status::success) {
        ThrowInfo(ErrorCode::UnexpectedError,
                  "failed to Deserialize index, " + KnowhereStatusString(stat));
    }
    span_load_engine->End();

    auto local_chunk_manager =
        storage::LocalChunkManagerSingleton::GetInstance().GetChunkManager();
    auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix();

    // Valid-data file layout: [size_t count][ceil(count / 8) bitmap bytes],
    // bit i of the bitmap (LSB-first within each byte) marks row i as valid.
    auto valid_data_path = local_index_path_prefix + "/" + VALID_DATA_KEY;
    if (local_chunk_manager->Exist(valid_data_path)) {
        // Start from a defined value so the sizes derived below never depend
        // on an indeterminate stack value if the read leaves it untouched.
        size_t count = 0;
        local_chunk_manager->Read(valid_data_path, 0, &count, sizeof(size_t));
        size_t byte_size = (count + 7) / 8;
        std::vector<uint8_t> valid_bitmap(byte_size);
        local_chunk_manager->Read(
            valid_data_path, sizeof(size_t), valid_bitmap.data(), byte_size);
        // Convert bitmap to bool array
        auto valid_data = std::make_unique<bool[]>(count);
        for (size_t i = 0; i < count; ++i) {
            valid_data[i] = (valid_bitmap[i / 8] >> (i % 8)) & 1;
        }
        BuildValidData(valid_data.get(), count);
    }

    SetDim(index_.Dim());
}

// Serializes the index and reports what the file manager has recorded.
//
// Returns index stats built from the file manager's total added file size
// and its remote-path-to-size map. `config` is not read here.
// Throws UnexpectedError when serialization does not return success.
template <typename T>
IndexStatsPtr
VectorDiskAnnIndex<T>::Upload(const Config& config) {
    BinarySet serialized;
    const auto status = index_.Serialize(serialized);
    if (status != knowhere::Status::success) {
        ThrowInfo(ErrorCode::UnexpectedError,
                  "failed to serialize index, " + KnowhereStatusString(status));
    }

    auto size_by_remote_path = file_manager_->GetRemotePathsToFileSize();
    auto total_added_size = file_manager_->GetAddedTotalFileSize();
    return IndexStats::NewFromSizeMap(total_added_size, size_by_remote_path);
}

// Builds the disk index from raw field data obtained through the file manager.
//
// Steps:
//   1. Cache the raw data to local disk. The config handed to the file manager
//      also carries the embedding-list flag, the offsets path (embedding lists
//      only) and the path where a valid-data file may be written.
//   2. Assemble the knowhere build config (raw data path, index prefix,
//      thread count for DISKANN, optional scalar fields path).
//   3. Run the index build.
//   4. Register the valid-data file with the file manager if it exists, then
//      remove the local raw-data directory.
//
// Throws UnexpectedError when the embedding-list offsets file is missing and
// IndexBuildError when the build does not return success. Asserts that the
// DISKANN build thread number is present in the config.
template <typename T>
void
VectorDiskAnnIndex<T>::Build(const Config& config) {
    LOG_INFO("start build disk index, build_id: {}",
             config.value("build_id", "unknown"));

    auto local_chunk_manager =
        storage::LocalChunkManagerSingleton::GetInstance().GetChunkManager();
    knowhere::Json build_config;
    build_config.update(config);

    auto segment_id = file_manager_->GetFieldDataMeta().segment_id;
    auto field_id = file_manager_->GetFieldDataMeta().field_id;

    // A non-NONE element type marks this field as an embedding list.
    auto is_embedding_list = (elem_type_ != DataType::NONE);
    Config config_with_emb_list = config;
    config_with_emb_list[EMB_LIST] = is_embedding_list;

    std::string offsets_path;
    // Set offsets path in config for VECTOR_ARRAY
    if (is_embedding_list) {
        offsets_path = storage::GenFieldRawDataPathPrefix(
                           local_chunk_manager, segment_id, field_id) +
                       "offset";
        config_with_emb_list[EMB_LIST_OFFSETS_PATH] = offsets_path;
    }

    // Set valid data path to track nullable vector fields
    auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix();
    auto valid_data_path = local_index_path_prefix + "/" + VALID_DATA_KEY;
    config_with_emb_list[VALID_DATA_PATH_KEY] = valid_data_path;

    auto local_data_path =
        file_manager_->CacheRawDataToDisk<T>(config_with_emb_list);
    build_config[DISK_ANN_RAW_DATA_PATH] = local_data_path;

    // For VECTOR_ARRAY, verify offsets file exists and pass its path to build_config
    if (is_embedding_list) {
        if (!local_chunk_manager->Exist(offsets_path)) {
            ThrowInfo(ErrorCode::UnexpectedError,
                      fmt::format("Embedding list offsets file not found: {}",
                                  offsets_path));
        }
        build_config[EMB_LIST_OFFSETS_PATH] = offsets_path;
    }

    build_config[DISK_ANN_PREFIX_PATH] = local_index_path_prefix;

    if (GetIndexType() == knowhere::IndexEnum::INDEX_DISKANN) {
        // The thread count arrives as a string and is forwarded as an int.
        auto num_threads = GetValueFromConfig<std::string>(
            build_config, DISK_ANN_BUILD_THREAD_NUM);
        AssertInfo(
            num_threads.has_value(),
            "param " + std::string(DISK_ANN_BUILD_THREAD_NUM) + " is empty");
        build_config[DISK_ANN_THREADS_NUM] =
            std::atoi(num_threads.value().c_str());
    }

    auto opt_fields = GetValueFromConfig<OptFieldT>(config, VEC_OPT_FIELDS);
    auto is_partition_key_isolation =
        GetValueFromConfig<bool>(build_config, "partition_key_isolation");
    if (opt_fields.has_value() &&
        index_.IsAdditionalScalarSupported(
            is_partition_key_isolation.value_or(false))) {
        build_config[VEC_OPT_FIELDS_PATH] =
            file_manager_->CacheOptFieldToDisk(config);
        // `partition_key_isolation` is already in the config, so it falls through
        // into the index Build call directly
    }

    // These keys are dropped from the config before it reaches the index.
    build_config.erase(INSERT_FILES_KEY);
    build_config.erase(VEC_OPT_FIELDS);
    auto stat = index_.Build({}, build_config);
    if (stat != knowhere::Status::success) {
        ThrowInfo(ErrorCode::IndexBuildError,
                  "failed to build disk index, " + KnowhereStatusString(stat));
    }

    // Add valid_data file to index if it was created (nullable vector field)
    if (local_chunk_manager->Exist(valid_data_path)) {
        file_manager_->AddFile(valid_data_path);
    }

    // The cached raw data is no longer needed once the index is built.
    local_chunk_manager->RemoveDir(storage::GenFieldRawDataPathPrefix(
        local_chunk_manager, segment_id, field_id));

    LOG_INFO("build disk index done, build_id: {}",
             config.value("build_id", "unknown"));
}

// Builds the disk index from an in-memory dataset.
//
// The dataset is first spilled to local files, whose paths are then passed to
// the index through the build config:
//   - raw data file:  [uint32 rows][uint32 dim][rows * row_size bytes]
//   - offsets file (embedding lists only):
//                     [uint32 num_offsets][num_offsets * size_t offsets]
// After a successful build, if valid data is present, a valid-data file
//   [size_t count][ceil(count / 8) bitmap bytes, LSB-first]
// is written under the local index prefix and registered with the file
// manager. The local raw-data directory is removed at the end.
//
// Throws UnexpectedError when an embedding-list dataset carries no offsets and
// IndexBuildError when the build does not return success. Asserts that the
// DISKANN build thread number is present in the config.
template <typename T>
void
VectorDiskAnnIndex<T>::BuildWithDataset(const DatasetPtr& dataset,
                                        const Config& config) {
    auto local_chunk_manager =
        storage::LocalChunkManagerSingleton::GetInstance().GetChunkManager();
    knowhere::Json build_config;
    build_config.update(config);
    // set data path
    auto segment_id = file_manager_->GetFieldDataMeta().segment_id;
    auto field_id = file_manager_->GetFieldDataMeta().field_id;
    auto local_data_path = storage::GenFieldRawDataPathPrefix(
                               local_chunk_manager, segment_id, field_id) +
                           "raw_data";
    build_config[DISK_ANN_RAW_DATA_PATH] = local_data_path;

    auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix();
    build_config[DISK_ANN_PREFIX_PATH] = local_index_path_prefix;

    if (GetIndexType() == knowhere::IndexEnum::INDEX_DISKANN) {
        // The thread count arrives as a string and is forwarded as an int.
        auto num_threads = GetValueFromConfig<std::string>(
            build_config, DISK_ANN_BUILD_THREAD_NUM);
        AssertInfo(
            num_threads.has_value(),
            "param " + std::string(DISK_ANN_BUILD_THREAD_NUM) + " is empty");
        build_config[DISK_ANN_THREADS_NUM] =
            std::atoi(num_threads.value().c_str());
    }
    if (!local_chunk_manager->Exist(local_data_path)) {
        local_chunk_manager->CreateFile(local_data_path);
    }

    // Raw data file: header (rows, dim) followed by the tensor bytes.
    int64_t offset = 0;
    auto num = static_cast<uint32_t>(milvus::GetDatasetRows(dataset));
    local_chunk_manager->Write(local_data_path, offset, &num, sizeof(num));
    offset += sizeof(num);

    auto dim = static_cast<uint32_t>(milvus::GetDatasetDim(dataset));
    local_chunk_manager->Write(local_data_path, offset, &dim, sizeof(dim));
    offset += sizeof(dim);

    size_t data_size = static_cast<size_t>(num) * milvus::GetVecRowSize<T>(dim);
    auto raw_data = const_cast<void*>(milvus::GetDatasetTensor(dataset));
    local_chunk_manager->Write(local_data_path, offset, raw_data, data_size);

    // For VECTOR_ARRAY, write offsets to a separate file and pass the path to knowhere
    if (elem_type_ != DataType::NONE) {
        auto offsets =
            dataset->Get<const size_t*>(knowhere::meta::EMB_LIST_OFFSET);
        if (offsets == nullptr) {
            ThrowInfo(ErrorCode::UnexpectedError,
                      "Embedding list offsets is empty when build index");
        }

        // Write offsets to disk file (use same path convention as Build method)
        std::string offsets_path =
            storage::GenFieldRawDataPathPrefix(
                local_chunk_manager, segment_id, field_id) +
            "offset";
        local_chunk_manager->CreateFile(offsets_path);

        // The number of offsets written is the dataset row count plus one.
        const uint32_t num_offsets = num + 1;

        // Write offsets to file
        // Format: [num_offsets][offsets_data]
        int64_t write_pos = 0;
        uint32_t num_offsets_header = num_offsets;
        local_chunk_manager->Write(
            offsets_path, write_pos, &num_offsets_header, sizeof(uint32_t));
        write_pos += sizeof(uint32_t);

        local_chunk_manager->Write(
            offsets_path,
            write_pos,
            const_cast<void*>(static_cast<const void*>(offsets)),
            num_offsets * sizeof(size_t));

        build_config[EMB_LIST_OFFSETS_PATH] = offsets_path;
    }

    auto stat = index_.Build({}, build_config);
    if (stat != knowhere::Status::success) {
        ThrowInfo(ErrorCode::IndexBuildError,
                  "failed to build index, " + KnowhereStatusString(stat));
    }

    if (HasValidData()) {
        auto valid_data_path = local_index_path_prefix + "/" + VALID_DATA_KEY;
        // Same convention as the raw-data file above: make sure the file
        // exists before issuing positional writes into it.
        if (!local_chunk_manager->Exist(valid_data_path)) {
            local_chunk_manager->CreateFile(valid_data_path);
        }
        size_t count = offset_mapping_.GetTotalCount();
        local_chunk_manager->Write(valid_data_path, 0, &count, sizeof(size_t));
        // Pack one validity bit per row, LSB-first within each byte.
        size_t byte_size = (count + 7) / 8;
        std::vector<uint8_t> packed_data(byte_size, 0);
        for (size_t i = 0; i < count; ++i) {
            if (offset_mapping_.IsValid(i)) {
                packed_data[i / 8] |= (1 << (i % 8));
            }
        }
        local_chunk_manager->Write(
            valid_data_path, sizeof(size_t), packed_data.data(), byte_size);
        file_manager_->AddFile(valid_data_path);
    }

    // The spilled raw data is no longer needed once the index is built.
    local_chunk_manager->RemoveDir(storage::GenFieldRawDataPathPrefix(
        local_chunk_manager, segment_id, field_id));

    // TODO ::
    // SetDim(index_->Dim());
}

// Runs a top-k or range search against the index and copies the hits into
// `search_result`.
//
// The search config is derived from `search_info`; for DISKANN the search
// list size (when supplied), beam width and PQ code budget are set as well.
// If CheckAndUpdateKnowhereRangeSearchParam() reports a range search, the
// range result is regenerated through ReGenRangeSearchResult(); otherwise a
// plain search is issued. Distances are rounded to
// `search_info.round_decimal_` decimal places unless that value is -1.
//
// Asserts that the index metric type matches `search_info.metric_type_`;
// throws UnexpectedError when the underlying search returns no value.
template <typename T>
void
VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
                             const SearchInfo& search_info,
                             const BitsetView& bitset,
                             milvus::OpContext* op_context,
                             SearchResult& search_result) const {
    AssertInfo(GetMetricType() == search_info.metric_type_,
               "Metric type of field index isn't the same with search info");
    auto num_rows = dataset->GetRows();
    auto topk = search_info.topk_;

    knowhere::Json search_config = PrepareSearchParams(search_info);

    if (GetIndexType() == knowhere::IndexEnum::INDEX_DISKANN) {
        // set search list size
        if (CheckKeyInConfig(search_info.search_params_, DISK_ANN_QUERY_LIST)) {
            search_config[DISK_ANN_SEARCH_LIST_SIZE] =
                search_info.search_params_[DISK_ANN_QUERY_LIST];
        }
        // set beamwidth
        search_config[DISK_ANN_QUERY_BEAMWIDTH] = int(search_beamwidth_);
        // set json reset field, will be removed later
        search_config[DISK_ANN_PQ_CODE_BUDGET] = 0.0;
    }

    // set index prefix, will be removed later
    auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix();
    search_config[DISK_ANN_PREFIX_PATH] = local_index_path_prefix;

    auto final = [&] {
        if (CheckAndUpdateKnowhereRangeSearchParam(
                search_info, topk, GetMetricType(), search_config)) {
            auto res =
                index_.RangeSearch(dataset, search_config, bitset, op_context);
            if (!res.has_value()) {
                ThrowInfo(ErrorCode::UnexpectedError,
                          fmt::format("failed to range search: {}: {}",
                                      KnowhereStatusString(res.error()),
                                      res.what()));
            }
            return ReGenRangeSearchResult(
                res.value(), topk, num_rows, GetMetricType());
        } else {
            auto res =
                index_.Search(dataset, search_config, bitset, op_context);
            if (!res.has_value()) {
                ThrowInfo(ErrorCode::UnexpectedError,
                          fmt::format("failed to search: {}: {}",
                                      KnowhereStatusString(res.error()),
                                      res.what()));
            }
            return res.value();
        }
    }();

    auto ids = final->GetIds();
    // In embedding list query, final->GetRows() can be different from dataset->GetRows().
    auto num_queries = final->GetRows();
    float* distances = const_cast<float*>(final->GetDistance());
    final->SetIsOwner(true);

    auto round_decimal = search_info.round_decimal_;
    auto total_num = num_queries * topk;

    if (round_decimal != -1) {
        const float multiplier = pow(10.0, round_decimal);
        // Use the same type as total_num for the counter: a plain `int` would
        // overflow once nq * topk no longer fits in 32 bits.
        for (decltype(total_num) i = 0; i < total_num; ++i) {
            distances[i] = std::round(distances[i] * multiplier) / multiplier;
        }
    }
    search_result.seg_offsets_.resize(total_num);
    search_result.distances_.resize(total_num);
    search_result.total_nq_ = num_queries;
    search_result.unity_topK_ = topk;
    std::copy_n(ids, total_num, search_result.seg_offsets_.data());
    std::copy_n(distances, total_num, search_result.distances_.data());
}

// Forwards to the underlying index's AnnIterator() and returns its result
// (iterators on success, or the error carried by the expected) unchanged.
template <typename T>
knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
VectorDiskAnnIndex<T>::VectorIterators(const DatasetPtr dataset,
                                       const knowhere::Json& conf,
                                       const BitsetView& bitset) const {
    auto iterators = index_.AnnIterator(dataset, conf, bitset);
    return iterators;
}

// Reports whether the underlying index has raw data for this index's
// metric type.
template <typename T>
const bool
VectorDiskAnnIndex<T>::HasRawData() const {
    const auto& metric = GetMetricType();
    return index_.HasRawData(metric);
}

// Fetches the raw vectors for the ids in `dataset` and returns them as one
// contiguous byte buffer of rows * GetVecRowSize<T>(dim) bytes.
//
// Returns an empty buffer when `dataset` has no rows. Throws UnexpectedError
// when the index type is sparse or the lookup yields no value.
template <typename T>
std::vector<uint8_t>
VectorDiskAnnIndex<T>::GetVector(const DatasetPtr dataset) const {
    if (IndexIsSparse(GetIndexType())) {
        ThrowInfo(ErrorCode::UnexpectedError,
                  "failed to get vector, index is sparse");
    }

    // Nothing requested, nothing to fetch.
    if (dataset->GetRows() == 0) {
        return {};
    }

    auto fetched = index_.GetVectorByIds(dataset);
    if (!fetched.has_value()) {
        ThrowInfo(ErrorCode::UnexpectedError,
                  fmt::format("failed to get vector: {}: {}",
                              KnowhereStatusString(fetched.error()),
                              fetched.what()));
    }

    const auto& vectors = fetched.value();
    const int64_t byte_count =
        milvus::GetVecRowSize<T>(vectors->GetDim()) * vectors->GetRows();

    // Copy the tensor out so the returned buffer owns its bytes.
    std::vector<uint8_t> raw_data(byte_count);
    memcpy(raw_data.data(), vectors->GetTensor(), byte_count);
    return raw_data;
}

// Removes this index's local index directory, then its local raw-data
// directory, through the local chunk manager.
template <typename T>
void
VectorDiskAnnIndex<T>::CleanLocalData() {
    auto chunk_manager =
        storage::LocalChunkManagerSingleton::GetInstance().GetChunkManager();

    const auto index_dir = file_manager_->GetLocalIndexObjectPrefix();
    chunk_manager->RemoveDir(index_dir);

    const auto raw_data_dir = file_manager_->GetLocalRawDataObjectPrefix();
    chunk_manager->RemoveDir(raw_data_dir);
}

// Derives the knowhere load config from the caller's `config`.
//
// Always sets the local index prefix path. For DISKANN it additionally
// disables warm-up and the BFS cache, converts the load thread number from
// its string form to an int, and, when a query beam width is supplied,
// stores it in search_beamwidth_ for later searches. If `config` carries an
// mmap file path, that key is dropped and mmap is enabled instead.
//
// Asserts that the DISKANN load thread number is present in the config.
template <typename T>
inline knowhere::Json
VectorDiskAnnIndex<T>::update_load_json(const Config& config) {
    knowhere::Json load_config;
    load_config.update(config);

    // set data path
    auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix();
    load_config[DISK_ANN_PREFIX_PATH] = local_index_path_prefix;

    if (GetIndexType() == knowhere::IndexEnum::INDEX_DISKANN) {
        // set base info
        load_config[DISK_ANN_PREPARE_WARM_UP] = false;
        load_config[DISK_ANN_PREPARE_USE_BFS_CACHE] = false;

        // set threads number
        auto num_threads = GetValueFromConfig<std::string>(
            load_config, DISK_ANN_LOAD_THREAD_NUM);
        AssertInfo(
            num_threads.has_value(),
            "param " + std::string(DISK_ANN_LOAD_THREAD_NUM) + " is empty");
        load_config[DISK_ANN_THREADS_NUM] =
            std::atoi(num_threads.value().c_str());

        // update search_beamwidth
        auto beamwidth = GetValueFromConfig<std::string>(
            load_config, DISK_ANN_QUERY_BEAMWIDTH);
        if (beamwidth.has_value()) {
            search_beamwidth_ = std::atoi(beamwidth.value().c_str());
        }
    }

    // The mmap file path itself is not forwarded; only the enable flag is.
    if (config.contains(MMAP_FILE_PATH)) {
        load_config.erase(MMAP_FILE_PATH);
        load_config[ENABLE_MMAP] = true;
    }

    return load_config;
}

// Explicit instantiations: the member definitions above live in this
// translation unit, so every element type that is used must be listed here.
template class VectorDiskAnnIndex<float>;
template class VectorDiskAnnIndex<float16>;
template class VectorDiskAnnIndex<bfloat16>;
template class VectorDiskAnnIndex<bin1>;
template class VectorDiskAnnIndex<sparse_u32_f32>;
template class VectorDiskAnnIndex<int8>;

}  // namespace milvus::index
