// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "HashTable.h"
#include <cstring>
#include <memory>
#include <new>
#include "common/SimdUtil.h"

namespace milvus {
namespace exec {
// Prepares `lookup` for a group-by probe over `input`.
//
// Every key hasher in `lookup.hashers_` is bound to the child column of
// `input` found at the hasher's channel index; that child must be a
// ColumnVector, otherwise the assertion fails. `lookup` is then reset to
// `input->size()` and each hasher writes into `lookup.hashes_`. The boolean
// passed to hash() is false for the first key and true for every later key
// (presumably asking the hasher to combine with the hashes already present -
// confirm against VectorHasher).
//
// Raises OpTypeInvalid if a hasher is about to run while the table is in any
// hash mode other than kHash.
void
BaseHashTable::prepareForGroupProbe(HashLookup& lookup,
                                    const RowVectorPtr& input) {
    auto& hashers = lookup.hashers_;

    // Step 1: attach the matching input column to each key hasher.
    for (std::size_t key = 0; key < hashers.size(); ++key) {
        const auto column_idx = hashers[key]->ChannelIndex();
        ColumnVectorPtr column_ptr =
            std::dynamic_pointer_cast<ColumnVector>(input->child(column_idx));
        AssertInfo(column_ptr != nullptr,
                   "Failed to get column vector from row vector input");
        hashers[key]->setColumnData(column_ptr);
    }

    // Step 2: size the lookup state for this batch.
    lookup.reset(input->size());

    // Step 3: compute the hashes, one key column at a time.
    const bool hashModeSupported =
        hashMode() == BaseHashTable::HashMode::kHash;
    for (std::size_t key = 0; key < hashers.size(); ++key) {
        if (!hashModeSupported) {
            ThrowInfo(
                milvus::OpTypeInvalid,
                "Not support target hashMode, only support kHash for now");
        } else {
            hashers[key]->hash(key > 0, lookup.hashes_);
        }
    }
}

// Per-row cursor used while probing the tag-based hash table.
//
// Call sequence, as driven by HashTable::groupProbe() in this file:
//   1. preProbe()   - remember the row, locate its bucket and prefetch it.
//   2. firstProbe() - load the bucket's tags and pick up the first tag match.
//   3. fullProbe()  - compare candidates, moving from bucket to bucket until
//                     the key is found or an empty slot can take a new entry.
//
// `Table` is a template parameter; it has to expose table_, bucketOffset(),
// nextBucketOffset(), loadTags(), row() and numBuckets() as used below.
//
// All scalar members carry default initializers so that a freshly constructed
// state never exposes indeterminate values (e.g. through row() or through the
// `group_` check at the top of fullProbe()).
class ProbeState {
 public:
    enum class Operation { kProbe, kInsert, kErase };
    // Keeps the low 16 bits of the "empty slot" bitmask in fullProbe().
    static constexpr int32_t kFullMask = 0xffff;

    // Row index most recently passed to preProbe() (0 before the first call).
    int32_t
    row() const {
        return row_;
    }

    // Starts a probe for input row `row` whose hash is `hash`: records the
    // row, computes the starting bucket offset and the tag vector to look for,
    // forgets any previous candidate group and prefetches the bucket memory.
    template <typename Table>
    inline void
    preProbe(const Table& table, uint64_t hash, int32_t row) {
        row_ = row;
        bucketOffset_ = table.bucketOffset(hash);
        const auto tag = BaseHashTable::hashTag(hash);
        wantedTags_ = BaseHashTable::TagVector::broadcast(tag);
        group_ = nullptr;
        __builtin_prefetch(reinterpret_cast<uint8_t*>(table.table_) +
                           bucketOffset_);
    }

    // Loads the tags of the current bucket, computes the bitmask of slots
    // whose tag equals the wanted tag and, if there is at least one, loads the
    // first candidate row into `group_`.
    template <Operation op = Operation::kInsert, typename Table>
    inline void
    firstProbe(const Table& table) {
        tagsInTable_ = BaseHashTable::loadTags(
            reinterpret_cast<uint8_t*>(table.table_), bucketOffset_);
        hits_ = milvus::toBitMask(tagsInTable_ == wantedTags_);
        if (hits_) {
            loadNextHit<op>(table);
        }
    }

    // Finishes the probe started by preProbe()/firstProbe().
    //
    // compare(group, row) -> true when the stored group has the same keys as
    //                        input row `row`.
    // insert(row, index)  -> creates the entry for `row` at slot `index`
    //                        (bucket offset + slot position) and returns the
    //                        new group row.
    //
    // Returns the matching group row, or whatever `insert` returns when no
    // match exists and an empty slot was found. Only Operation::kInsert is
    // supported; any other `op` fails the assertion at run time. Throws
    // UnexpectedError when numBuckets() buckets were visited without finding
    // either a match or an empty slot.
    template <Operation op, typename Compare, typename Insert, typename Table>
    inline char*
    fullProbe(Table& table, Compare compare, Insert insert) {
        AssertInfo(op == Operation::kInsert,
                   "Only support insert operation for group cases");
        // Candidate already loaded by firstProbe(): check it first.
        if (group_ && compare(group_, row_)) {
            return group_;
        }
        const auto kEmptyGroup = BaseHashTable::TagVector::broadcast(0);
        for (int64_t numProbedBuckets = 0;
             numProbedBuckets < table.numBuckets();
             ++numProbedBuckets) {
            // Try every remaining tag match in the current bucket.
            while (hits_ > 0) {
                loadNextHit<op>(table);
                if (compare(group_, row_)) {
                    return group_;
                }
            }

            // No candidate in this bucket matched. If the bucket still has a
            // slot whose tag is 0 (empty), the key is new: insert it into the
            // slot selected by getAndClearLastSetBit(). Tombstone slots are
            // not handled by this implementation.
            uint16_t empty =
                milvus::toBitMask(tagsInTable_ == kEmptyGroup) & kFullMask;
            if (empty > 0) {
                auto pos = milvus::bits::getAndClearLastSetBit(empty);
                return insert(row_, bucketOffset_ + pos);
            }
            // Bucket is full of other keys: continue with the next bucket.
            bucketOffset_ = table.nextBucketOffset(bucketOffset_);
            tagsInTable_ = table.loadTags(bucketOffset_);
            hits_ = milvus::toBitMask(tagsInTable_ == wantedTags_);
        }
        ThrowInfo(UnexpectedError,
                  "Slots in hash table is not enough for hash operation, fail "
                  "the request");
    }

 private:
    // Currently unused.
    static constexpr uint8_t kNotSet = 0xff;

    // Pops one set bit from `hits_`, loads the row stored in that slot of the
    // current bucket into `group_` and prefetches it.
    template <Operation op, typename Table>
    inline void
    loadNextHit(Table& table) {
        const int32_t hit = milvus::bits::getAndClearLastSetBit(hits_);
        group_ = table.row(bucketOffset_, hit);
        __builtin_prefetch(group_);
    }

    // Candidate group row for the current probe; nullptr when none is loaded.
    char* group_{nullptr};
    // Tag of the probed hash replicated across all lanes.
    BaseHashTable::TagVector wantedTags_;
    // Tags of the bucket currently being examined.
    BaseHashTable::TagVector tagsInTable_;
    // Input row index being probed.
    int32_t row_{0};
    // Offset of the bucket currently being examined.
    int64_t bucketOffset_{0};
    // Bitmask of slots in the current bucket whose tag matched and that have
    // not been compared yet.
    BaseHashTable::MaskType hits_{};
    //uint8_t indexInTags_ = kNotSet;
};

// Replaces the slot table with a freshly zero-filled buffer for `size`
// entries and recomputes the sizing fields derived from it.
//
// `size` must be a non-zero power of two, and the resulting byte size must be
// a multiple of kBucketSize; otherwise an assertion fails. Any previously
// allocated buffer is released before the new one is created.
void
HashTable::allocateTables(uint64_t size) {
    AssertInfo(milvus::bits::isPowerOfTwo(size),
               "Size:{} for allocating tables must be a power of two",
               size);
    AssertInfo(size > 0,
               "Size:{} for allocating tables must be larger than zero",
               size);

    // The buffer is aligned to a 64-byte cache line.
    constexpr std::align_val_t kTableAlignment = std::align_val_t(64);

    // Drop the old buffer, if there is one.
    if (table_) {
        ::operator delete(table_, kTableAlignment);
        table_ = nullptr;
    }

    capacity_ = size;
    const uint64_t byteSize = capacity_ * tableSlotSize();
    AssertInfo(byteSize % kBucketSize == 0,
               "byteSize:{} for hashTable must be a multiple of kBucketSize:{}",
               byteSize,
               kBucketSize);

    // Sizing fields derived from the byte size of the table.
    sizeMask_ = byteSize - 1;
    numBuckets_ = byteSize / kBucketSize;
    bucketOffsetMask_ = sizeMask_ & ~(kBucketSize - 1);
    sizeBits_ = __builtin_popcountll(sizeMask_);

    // The total size is 8 bytes per slot, in groups of 16 slots with 16 bytes
    // of tags and 16 * 6 bytes of pointers and a padding of 16 bytes to round
    // up the cache line.
    // TODO support memory pool here to avoid OOM
    void* buffer = ::operator new(byteSize, kTableAlignment);
    table_ = static_cast<char*>(buffer);
    std::memset(table_, 0, byteSize);
}

// Makes sure a slot table exists before probing.
//
// Asserts that the table, when it has a capacity, still holds fewer distinct
// entries than that capacity. If no buffer is allocated yet (or the capacity
// is zero), a new table is allocated with the size returned by
// newHashTableEntriesNumber(numDistinct_, numNew).
void
HashTable::checkSizeAndAllocateTable(int32_t numNew) {
    AssertInfo(capacity_ == 0 || capacity_ > numDistinct_,
               "capacity_ {}, numDistinct {}",
               capacity_,
               numDistinct_);

    const bool tableReady = table_ != nullptr && capacity_ != 0;
    if (tableReady) {
        return;
    }
    allocateTables(newHashTableEntriesNumber(numDistinct_, numNew));
}

// Returns true when every key column of the stored row `group` equals the
// corresponding key value of input row `row`, using the column data currently
// bound to the hashers in `lookup`. Stops at the first key that differs.
bool
HashTable::compareKeys(const char* group,
                       milvus::exec::HashLookup& lookup,
                       milvus::vector_size_t row) {
    auto& hashers = lookup.hashers_;
    const int32_t keyCount = hashers.size();
    int32_t key = 0;
    while (key < keyCount) {
        const bool sameKey = rows_->equals(
            group, rows()->columnAt(key), hashers[key]->columnData(), row);
        if (!sameKey) {
            return false;
        }
        ++key;
    }
    return true;
}

// Copies the key values of input row `row` into the row container entry that
// `lookup.hits_[row]` points at, one key column per hasher.
void
HashTable::storeKeys(milvus::exec::HashLookup& lookup,
                     milvus::vector_size_t row) {
    auto& hashers = lookup.hashers_;
    for (int32_t key = 0; key < hashers.size(); ++key) {
        rows_->store(hashers[key]->columnData(), row, lookup.hits_[row], key);
    }
}

void
HashTable::storeRowPointer(uint64_t index, uint64_t hash, char* row) {
    const int64_t bktOffset = bucketOffset(index);
    auto* bucket = bucketAt(bktOffset);
    const auto slotIndex = index & (sizeof(TagVector) - 1);
    bucket->setTag(slotIndex, hashTag(hash));
    bucket->setPointer(slotIndex, row);
}

// Creates a new group for input row `row` and registers it in slot `index`.
//
// A fresh row is obtained from the row container and recorded in
// `lookup.hits_[row]` first, because storeKeys() writes the keys to the row it
// finds there. The slot is then filled with the tag and pointer, the distinct
// counter is bumped and `row` is appended to `lookup.newGroups_`.
// Returns the newly created group row.
char*
HashTable::insertEntry(milvus::exec::HashLookup& lookup,
                       uint64_t index,
                       milvus::vector_size_t row) {
    char* const newGroup = rows_->newRow();
    lookup.hits_[row] = newGroup;

    storeKeys(lookup, row);
    storeRowPointer(index, lookup.hashes_[row], newGroup);

    ++numDistinct_;
    lookup.newGroups_.push_back(row);
    return newGroup;
}

// Completes the probe held in `state` in insert mode and stores the resulting
// group row in `lookup.hits_` at the row index `state` is working on.
// Key comparison is delegated to compareKeys() and creation of a missing
// group to insertEntry().
FOLLY_ALWAYS_INLINE void
HashTable::fullProbe(HashLookup& lookup, ProbeState& state) {
    auto keysMatch = [&](char* group, int32_t row) {
        return compareKeys(group, lookup, row);
    };
    auto addGroup = [&](int32_t row, uint64_t index) {
        return insertEntry(lookup, index, row);
    };

    char* const found = state.fullProbe<ProbeState::Operation::kInsert>(
        *this, keysMatch, addGroup);
    lookup.hits_[state.row()] = found;
}

// Probes the table with every hash in `lookup.hashes_`, inserting a new group
// for each key that is not present yet. After the call `lookup.hits_[i]` holds
// the group row for input row i.
//
// Only HashMode::kHash is supported. The table is allocated on demand before
// the first row is probed.
void
HashTable::groupProbe(milvus::exec::HashLookup& lookup) {
    AssertInfo(hashMode_ == HashMode::kHash, "Only support kHash mode for now");
    checkSizeAndAllocateTable(0);

    ProbeState probe;
    const auto rowCount = lookup.hashes_.size();
    for (int32_t row = 0; row < rowCount; ++row) {
        probe.preProbe(*this, lookup.hashes_[row], row);
        probe.firstProbe<ProbeState::Operation::kInsert>(*this);
        fullProbe(lookup, probe);
    }
}

// Placeholder: switching the hash mode is not implemented, so this call has
// no effect and both arguments are ignored.
void
HashTable::setHashMode(HashMode mode, int32_t numNew) {
    // TODO set hash mode kArray/kHash/kNormalizedKey
    static_cast<void>(mode);
    static_cast<void>(numNew);
}

void
HashTable::clear(bool freeTable) {
    if (table_) {
        ::operator delete(table_, std::align_val_t(64));
        table_ = nullptr;
    }
    numDistinct_ = 0;
    capacity_ = 0;
    numBuckets_ = 0;
    sizeMask_ = 0;
    bucketOffsetMask_ = 0;
}

}  // namespace exec
}  // namespace milvus
