// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <type_traits>
#include <utility>
#include <vector>
#include <memory>

#include <arrow/array.h>
#include <arrow/array/builder_primitive.h>
#include <fmt/core.h>

#include "FieldMeta.h"
#include "Types.h"

namespace milvus {

class Array {
 public:
    // Creates an empty array: no data buffer, zero length and byte size,
    // element type DataType::NONE and no offsets table.
    Array() = default;

    // Both buffers are held in std::unique_ptr members, so the defaulted
    // destructor releases them.
    ~Array() = default;

    Array(char* data,
          int len,
          size_t size,
          DataType element_type,
          const uint32_t* offsets_ptr)
        : size_(size), length_(len), element_type_(element_type) {
        data_ = std::make_unique<char[]>(size);
        std::copy(data, data + size, data_.get());
        if (IsVariableDataType(element_type)) {
            AssertInfo(offsets_ptr != nullptr,
                       "For variable type elements in array, offsets_ptr must "
                       "be non-null");
            offsets_ptr_ = std::make_unique<uint32_t[]>(len);
            std::copy(offsets_ptr, offsets_ptr + len, offsets_ptr_.get());
        }
    }

    explicit Array(const ScalarFieldProto& field_data) {
        switch (field_data.data_case()) {
            case ScalarFieldProto::kBoolData: {
                element_type_ = DataType::BOOL;
                length_ = field_data.bool_data().data().size();
                size_ = length_;
                data_ = std::make_unique<char[]>(size_);
                for (int i = 0; i < length_; ++i) {
                    reinterpret_cast<bool*>(data_.get())[i] =
                        field_data.bool_data().data(i);
                }
                break;
            }
            case ScalarFieldProto::kIntData: {
                element_type_ = DataType::INT32;
                length_ = field_data.int_data().data().size();
                size_ = length_ * sizeof(int32_t);
                data_ = std::make_unique<char[]>(size_);
                for (int i = 0; i < length_; ++i) {
                    reinterpret_cast<int*>(data_.get())[i] =
                        field_data.int_data().data(i);
                }
                break;
            }
            case ScalarFieldProto::kLongData: {
                element_type_ = DataType::INT64;
                length_ = field_data.long_data().data().size();
                size_ = length_ * sizeof(int64_t);
                data_ = std::make_unique<char[]>(size_);
                for (int i = 0; i < length_; ++i) {
                    reinterpret_cast<int64_t*>(data_.get())[i] =
                        field_data.long_data().data(i);
                }
                break;
            }
            case ScalarFieldProto::kFloatData: {
                element_type_ = DataType::FLOAT;
                length_ = field_data.float_data().data().size();
                size_ = length_ * sizeof(float);
                data_ = std::make_unique<char[]>(size_);
                for (int i = 0; i < length_; ++i) {
                    reinterpret_cast<float*>(data_.get())[i] =
                        field_data.float_data().data(i);
                }
                break;
            }
            case ScalarFieldProto::kDoubleData: {
                element_type_ = DataType::DOUBLE;
                length_ = field_data.double_data().data().size();
                size_ = length_ * sizeof(double);
                data_ = std::make_unique<char[]>(size_);
                for (int i = 0; i < length_; ++i) {
                    reinterpret_cast<double*>(data_.get())[i] =
                        field_data.double_data().data(i);
                }
                break;
            }
            case ScalarFieldProto::kStringData: {
                element_type_ = DataType::STRING;
                length_ = field_data.string_data().data().size();
                offsets_ptr_ = std::make_unique<uint32_t[]>(length_);
                for (int i = 0; i < length_; ++i) {
                    offsets_ptr_[i] = size_;
                    size_ +=
                        field_data.string_data()
                            .data(i)
                            .size();  //type risk here between uint32_t vs size_t
                }
                data_ = std::make_unique<char[]>(size_);
                for (int i = 0; i < length_; ++i) {
                    std::copy_n(field_data.string_data().data(i).data(),
                                field_data.string_data().data(i).size(),
                                data_.get() + offsets_ptr_[i]);
                }
                break;
            }
            default: {
                // empty array
            }
        }
    }

    // Deep-copies another Array: duplicates its data buffer and, for
    // variable-length element types, its offsets table.
    //
    // Deliberately not noexcept: std::make_unique can throw std::bad_alloc,
    // and a noexcept specifier would turn that (and any exception raised by
    // AssertInfo - presumably it throws, confirm) into std::terminate.
    Array(const Array& array)
        : length_{array.length_},
          size_{array.size_},
          element_type_{array.element_type_} {
        data_ = std::make_unique<char[]>(array.size_);
        std::copy_n(array.data_.get(), array.size_, data_.get());
        if (IsVariableDataType(array.element_type_)) {
            // Message fixed: the two literals previously concatenated to
            // "offsets_ptrmust".
            AssertInfo(array.get_offsets_data() != nullptr,
                       "for array with variable length elements, offsets_ptr "
                       "must not be nullptr");
            offsets_ptr_ = std::make_unique<uint32_t[]>(length_);
            std::copy_n(
                array.get_offsets_data(), array.length(), offsets_ptr_.get());
        }
    }

    // Exchanges the complete state of two arrays member by member. Buffers
    // are exchanged by pointer, so no element data is copied.
    friend void
    swap(Array& array1, Array& array2) noexcept {
        using std::swap;
        // Buffers first ...
        swap(array1.data_, array2.data_);
        swap(array1.offsets_ptr_, array2.offsets_ptr_);
        // ... then the scalar bookkeeping that describes them.
        swap(array1.length_, array2.length_);
        swap(array1.size_, array2.size_);
        swap(array1.element_type_, array2.element_type_);
    }

    // Copy assignment: builds a full copy of `array` first, then exchanges
    // it with *this, so *this is only modified after the copy succeeded.
    // The previous contents are released when the local copy goes out of
    // scope.
    Array&
    operator=(const Array& array) {
        Array replacement(array);
        swap(*this, replacement);
        return *this;
    }

    // Move constructor: takes over the buffers and bookkeeping of `other`
    // and leaves `other` in the same state as a default-constructed Array.
    Array(Array&& other) noexcept
        : data_(std::move(other.data_)),
          length_(std::exchange(other.length_, 0)),
          size_(std::exchange(other.size_, 0)),
          element_type_(std::exchange(other.element_type_, DataType::NONE)),
          offsets_ptr_(std::move(other.offsets_ptr_)) {
    }

    // Move assignment: exchanges state with `other`. After the call `other`
    // holds whatever *this held before, and releases it when destroyed.
    Array&
    operator=(Array&& other) noexcept {
        Array& self = *this;
        swap(self, other);
        return self;
    }

    // Element-wise equality.
    //
    // Arrays with different element types or lengths are unequal; two empty
    // arrays of the same element type are equal. Otherwise every element
    // pair is read through get_data<T>() and compared with `!=`:
    //   INT8/INT16/INT32 as int, INT64 as int64_t, BOOL as bool,
    //   FLOAT as float, DOUBLE as double,
    //   STRING/VARCHAR/GEOMETRY as std::string_view.
    // Any other element type is reported through ThrowInfo(Unsupported, ...).
    bool
    operator==(const Array& arr) const {
        if (element_type_ != arr.element_type_ || length_ != arr.length_) {
            return false;
        }
        if (length_ == 0) {
            return true;
        }

        // Compares all element pairs after reading both sides as the type
        // of `tag`; only the tag's type matters, its value is ignored.
        const auto same_as = [this, &arr](auto tag) {
            using Elem = decltype(tag);
            for (int pos = 0; pos < length_; ++pos) {
                if (this->get_data<Elem>(pos) != arr.get_data<Elem>(pos)) {
                    return false;
                }
            }
            return true;
        };

        switch (element_type_) {
            case DataType::INT64:
                return same_as(int64_t{});
            case DataType::BOOL:
                return same_as(bool{});
            case DataType::DOUBLE:
                return same_as(double{});
            case DataType::FLOAT:
                return same_as(float{});
            case DataType::INT32:
            case DataType::INT16:
            case DataType::INT8:
                return same_as(int{});
            case DataType::STRING:
            case DataType::VARCHAR:
            // treat Geometry as wkb string
            case DataType::GEOMETRY:
                return same_as(std::string_view{});
            default:
                ThrowInfo(Unsupported, "unsupported element type for array");
        }
    }

    // Reads element `index` as a T; `index` is asserted to lie in
    // [0, length_).
    //
    // - T = std::string / std::string_view: the element starts at
    //   offsets_ptr_[index]; its length is the distance to the next offset,
    //   or to the end of the buffer (size_) for the last element.
    // - T = int, int8_t, int16_t, int64_t, float, double: the element is
    //   read using the stored element type (INT8/INT16/INT32 are read as
    //   int32_t slots) and then converted to T. Other stored element types
    //   are reported through ThrowInfo(Unsupported, ...).
    // - any other T (e.g. bool): the buffer is indexed directly as T[].
    template <typename T>
    T
    get_data(const int index) const {
        AssertInfo(index >= 0 && index < length_,
                   "index out of range, index={}, length={}",
                   index,
                   length_);

        constexpr bool kWantsText = std::is_same_v<T, std::string> ||
                                    std::is_same_v<T, std::string_view>;
        constexpr bool kWantsNumber =
            std::is_same_v<T, int> || std::is_same_v<T, int64_t> ||
            std::is_same_v<T, int8_t> || std::is_same_v<T, int16_t> ||
            std::is_same_v<T, float> || std::is_same_v<T, double>;

        char* const raw = data_.get();

        if constexpr (kWantsText) {
            const bool is_last = (index == length_ - 1);
            const size_t element_length =
                is_last ? size_ - offsets_ptr_[length_ - 1]
                        : offsets_ptr_[index + 1] - offsets_ptr_[index];
            return T(raw + offsets_ptr_[index], element_length);
        }
        if constexpr (kWantsNumber) {
            switch (element_type_) {
                case DataType::INT8:
                case DataType::INT16:
                case DataType::INT32:
                    return static_cast<T>(
                        reinterpret_cast<int32_t*>(raw)[index]);
                case DataType::INT64:
                    return static_cast<T>(
                        reinterpret_cast<int64_t*>(raw)[index]);
                case DataType::FLOAT:
                    return static_cast<T>(reinterpret_cast<float*>(raw)[index]);
                case DataType::DOUBLE:
                    return static_cast<T>(
                        reinterpret_cast<double*>(raw)[index]);
                default:
                    ThrowInfo(Unsupported,
                              "unsupported element type for array");
            }
        }
        return reinterpret_cast<T*>(raw)[index];
    }

    // Returns the raw per-element offsets table, or nullptr when none has
    // been allocated. Note the returned pointer is non-const even though
    // the method itself is const.
    uint32_t*
    get_offsets_data() const {
        return offsets_ptr_.get();
    }

    // Appends every element of this array to the ScalarFieldProto member
    // that matches the element type:
    //   BOOL -> bool_data, INT8/INT16/INT32 -> int_data, INT64 -> long_data,
    //   FLOAT -> float_data, DOUBLE -> double_data,
    //   STRING/VARCHAR -> string_data, GEOMETRY -> geometry_data.
    // Capacity for length_ elements is reserved up front. Any other element
    // type leaves `data_array` untouched.
    void
    output_data(ScalarFieldProto& data_array) const {
        switch (element_type_) {
            case DataType::BOOL: {
                auto* target = data_array.mutable_bool_data();
                target->mutable_data()->Reserve(length_);
                for (int pos = 0; pos < length_; ++pos) {
                    target->add_data(get_data<bool>(pos));
                }
                break;
            }
            case DataType::INT8:
            case DataType::INT16:
            case DataType::INT32: {
                auto* target = data_array.mutable_int_data();
                target->mutable_data()->Reserve(length_);
                for (int pos = 0; pos < length_; ++pos) {
                    target->add_data(get_data<int>(pos));
                }
                break;
            }
            case DataType::INT64: {
                auto* target = data_array.mutable_long_data();
                target->mutable_data()->Reserve(length_);
                for (int pos = 0; pos < length_; ++pos) {
                    target->add_data(get_data<int64_t>(pos));
                }
                break;
            }
            case DataType::STRING:
            case DataType::VARCHAR: {
                auto* target = data_array.mutable_string_data();
                target->mutable_data()->Reserve(length_);
                for (int pos = 0; pos < length_; ++pos) {
                    const auto text = get_data<std::string_view>(pos);
                    target->add_data(text.data(), text.size());
                }
                break;
            }
            case DataType::FLOAT: {
                auto* target = data_array.mutable_float_data();
                target->mutable_data()->Reserve(length_);
                for (int pos = 0; pos < length_; ++pos) {
                    target->add_data(get_data<float>(pos));
                }
                break;
            }
            case DataType::DOUBLE: {
                auto* target = data_array.mutable_double_data();
                target->mutable_data()->Reserve(length_);
                for (int pos = 0; pos < length_; ++pos) {
                    target->add_data(get_data<double>(pos));
                }
                break;
            }
            case DataType::GEOMETRY: {
                auto* target = data_array.mutable_geometry_data();
                target->mutable_data()->Reserve(length_);
                for (int pos = 0; pos < length_; ++pos) {
                    const auto blob = get_data<std::string_view>(pos);
                    target->add_data(blob.data(), blob.size());
                }
                break;
            }
            default: {
                // empty array
            }
        }
    }

    // Convenience overload: serializes this array into a fresh
    // ScalarFieldProto and returns it by value.
    ScalarFieldProto
    output_data() const {
        ScalarFieldProto serialized;
        output_data(serialized);
        return serialized;
    }

    // Number of elements in the array.
    int
    length() const {
        return length_;
    }

    // Size of the element data buffer in bytes (size_ converted to size_t).
    size_t
    byte_size() const {
        return size_;
    }

    // Element type recorded for this array; DataType::NONE for a
    // default-constructed array.
    DataType
    get_element_type() const {
        return element_type_;
    }

    // Read-only pointer to the owned element bytes; nullptr when no buffer
    // has been allocated.
    const char*
    data() const {
        return data_.get();
    }

    // Checks whether this array equals a plan-proto array literal.
    //
    // Returns false when the lengths differ, true when both are empty, and
    // false when a non-empty `arr2` is not flagged same_type(). Otherwise
    // element i is compared with arr2.array(i) using the accessor matching
    // the element type:
    //   BOOL -> bool_val, INT8/INT16/INT32/INT64 -> int64_val,
    //   FLOAT/DOUBLE -> float_val, VARCHAR/STRING/GEOMETRY -> string_val.
    // Any other element type yields false.
    bool
    is_same_array(const proto::plan::Array& arr2) const {
        if (arr2.array_size() != length_) {
            return false;
        }
        if (length_ == 0) {
            return true;
        }
        if (!arr2.same_type()) {
            return false;
        }

        // True when `equal_at(i)` holds for every index; stops at the first
        // mismatch.
        const auto holds_for_all = [this](const auto& equal_at) {
            for (int pos = 0; pos < length_; ++pos) {
                if (!equal_at(pos)) {
                    return false;
                }
            }
            return true;
        };

        switch (element_type_) {
            case DataType::BOOL:
                return holds_for_all([&](int pos) {
                    const auto val = this->get_data<bool>(pos);
                    return val == arr2.array(pos).bool_val();
                });
            case DataType::INT8:
            case DataType::INT16:
            case DataType::INT32:
                return holds_for_all([&](int pos) {
                    const auto val = this->get_data<int>(pos);
                    return val == arr2.array(pos).int64_val();
                });
            case DataType::INT64:
                return holds_for_all([&](int pos) {
                    const auto val = this->get_data<int64_t>(pos);
                    return val == arr2.array(pos).int64_val();
                });
            case DataType::FLOAT:
                return holds_for_all([&](int pos) {
                    const auto val = this->get_data<float>(pos);
                    return val == arr2.array(pos).float_val();
                });
            case DataType::DOUBLE:
                return holds_for_all([&](int pos) {
                    const auto val = this->get_data<double>(pos);
                    return val == arr2.array(pos).float_val();
                });
            case DataType::VARCHAR:
            case DataType::STRING:
            case DataType::GEOMETRY:
                return holds_for_all([&](int pos) {
                    const auto val = this->get_data<std::string>(pos);
                    return val == arr2.array(pos).string_val();
                });
            default:
                return false;
        }
    }

 private:
    // Owned buffer holding the element bytes.
    std::unique_ptr<char[]> data_{nullptr};
    // Number of elements.
    int length_ = 0;
    // Size of data_ in bytes.
    int size_ = 0;
    // Type shared by all elements; NONE for a default-constructed array.
    DataType element_type_ = DataType::NONE;
    // Owned table of length_ start offsets into data_. Only allocated for
    // variable-length element types; nullptr otherwise.
    std::unique_ptr<uint32_t[]> offsets_ptr_{nullptr};
};

class ArrayView {
 public:
    // Creates an empty view: null data and offsets pointers, zero length and
    // byte size, element type DataType::NONE. Note that copy-constructing
    // from such a view trips the non-null data assertion in the copy
    // constructor.
    ArrayView() = default;

    ArrayView(const ArrayView& other)
        : data_(other.data_),
          length_(other.length_),
          size_(other.size_),
          element_type_(other.element_type_),
          offsets_ptr_(other.offsets_ptr_) {
        AssertInfo(data_ != nullptr,
                   "data pointer for ArrayView cannot be nullptr");
        if (IsVariableDataType(element_type_)) {
            AssertInfo(offsets_ptr_ != nullptr,
                       "for array with variable length elements, offsets_ptr "
                       "must not be nullptr");
        }
    }

    ArrayView(char* data,
              int len,
              size_t size,
              DataType element_type,
              uint32_t* offsets_ptr)
        : data_(data),
          length_(len),
          size_(size),
          element_type_(element_type),
          offsets_ptr_(offsets_ptr) {
        AssertInfo(data != nullptr,
                   "data pointer for ArrayView cannot be nullptr");
        if (IsVariableDataType(element_type_)) {
            AssertInfo(offsets_ptr != nullptr,
                       "for array with variable length elements, offsets_ptr "
                       "must not be nullptr");
        }
    }

    // Reads element `index` as a T; `index` is asserted to lie in
    // [0, length_).
    //
    // - T = std::string / std::string_view: the element starts at
    //   offsets_ptr_[index]; its length is the distance to the next offset,
    //   or to the end of the buffer (size_) for the last element.
    // - T = int, int8_t, int16_t, int64_t, float, double: the element is
    //   read using the stored element type (INT8/INT16/INT32 are read as
    //   int32_t slots) and then converted to T. Other stored element types
    //   are reported through ThrowInfo(Unsupported, ...).
    // - any other T (e.g. bool): the buffer is indexed directly as T[].
    //
    // int8_t and int16_t are part of the numeric list so that they go
    // through the int32_t-slot path exactly as in Array::get_data; indexing
    // the buffer directly as int8_t[]/int16_t[] would read the wrong bytes.
    template <typename T>
    T
    get_data(const int index) const {
        AssertInfo(index >= 0 && index < length_,
                   "index out of range, index={}, length={}",
                   index,
                   length_);

        if constexpr (std::is_same_v<T, std::string> ||
                      std::is_same_v<T, std::string_view>) {
            size_t element_length =
                (index == length_ - 1)
                    ? size_ - offsets_ptr_[length_ - 1]
                    : offsets_ptr_[index + 1] - offsets_ptr_[index];
            return T(data_ + offsets_ptr_[index], element_length);
        }
        if constexpr (std::is_same_v<T, int> || std::is_same_v<T, int64_t> ||
                      std::is_same_v<T, int8_t> || std::is_same_v<T, int16_t> ||
                      std::is_same_v<T, float> || std::is_same_v<T, double>) {
            switch (element_type_) {
                case DataType::INT8:
                case DataType::INT16:
                case DataType::INT32:
                    return static_cast<T>(
                        reinterpret_cast<int32_t*>(data_)[index]);
                case DataType::INT64:
                    return static_cast<T>(
                        reinterpret_cast<int64_t*>(data_)[index]);
                case DataType::FLOAT:
                    return static_cast<T>(
                        reinterpret_cast<float*>(data_)[index]);
                case DataType::DOUBLE:
                    return static_cast<T>(
                        reinterpret_cast<double*>(data_)[index]);
                default:
                    ThrowInfo(Unsupported,
                              "unsupported element type for array");
            }
        }
        return reinterpret_cast<T*>(data_)[index];
    }

    // Appends every element of the viewed array to the ScalarFieldProto
    // member that matches the element type:
    //   BOOL -> bool_data, INT8/INT16/INT32 -> int_data, INT64 -> long_data,
    //   FLOAT -> float_data, DOUBLE -> double_data,
    //   STRING/VARCHAR -> string_data, GEOMETRY -> geometry_data.
    // Capacity for length_ elements is reserved up front. Any other element
    // type leaves `data_array` untouched.
    void
    output_data(ScalarFieldProto& data_array) const {
        switch (element_type_) {
            case DataType::BOOL: {
                auto* sink = data_array.mutable_bool_data();
                sink->mutable_data()->Reserve(length_);
                for (int pos = 0; pos < length_; ++pos) {
                    sink->add_data(get_data<bool>(pos));
                }
                break;
            }
            case DataType::INT8:
            case DataType::INT16:
            case DataType::INT32: {
                auto* sink = data_array.mutable_int_data();
                sink->mutable_data()->Reserve(length_);
                for (int pos = 0; pos < length_; ++pos) {
                    sink->add_data(get_data<int>(pos));
                }
                break;
            }
            case DataType::INT64: {
                auto* sink = data_array.mutable_long_data();
                sink->mutable_data()->Reserve(length_);
                for (int pos = 0; pos < length_; ++pos) {
                    sink->add_data(get_data<int64_t>(pos));
                }
                break;
            }
            case DataType::STRING:
            case DataType::VARCHAR: {
                auto* sink = data_array.mutable_string_data();
                sink->mutable_data()->Reserve(length_);
                for (int pos = 0; pos < length_; ++pos) {
                    const auto text = get_data<std::string_view>(pos);
                    sink->add_data(text.data(), text.size());
                }
                break;
            }
            case DataType::FLOAT: {
                auto* sink = data_array.mutable_float_data();
                sink->mutable_data()->Reserve(length_);
                for (int pos = 0; pos < length_; ++pos) {
                    sink->add_data(get_data<float>(pos));
                }
                break;
            }
            case DataType::DOUBLE: {
                auto* sink = data_array.mutable_double_data();
                sink->mutable_data()->Reserve(length_);
                for (int pos = 0; pos < length_; ++pos) {
                    sink->add_data(get_data<double>(pos));
                }
                break;
            }
            case DataType::GEOMETRY: {
                auto* sink = data_array.mutable_geometry_data();
                sink->mutable_data()->Reserve(length_);
                for (int pos = 0; pos < length_; ++pos) {
                    const auto blob = get_data<std::string_view>(pos);
                    sink->add_data(blob.data(), blob.size());
                }
                break;
            }
            default: {
                // empty array
            }
        }
    }

    // Materializes the viewed data as an owning Array and move-assigns it
    // into `array`. The Array constructor copies both the element bytes and
    // (for variable-length types) the offsets, so `array` does not alias
    // this view afterwards.
    void
    output_data(Array& array) const {
        Array owned(data_,
                    length_,
                    static_cast<size_t>(size_),
                    element_type_,
                    offsets_ptr_);
        array = std::move(owned);
    }

    // Convenience overload: serializes the viewed array into a fresh
    // ScalarFieldProto and returns it by value.
    ScalarFieldProto
    output_data() const {
        ScalarFieldProto serialized;
        output_data(serialized);
        return serialized;
    }

    // Number of elements in the viewed array.
    int
    length() const {
        return length_;
    }

    // Size of the viewed element data in bytes (size_ converted to size_t).
    size_t
    byte_size() const {
        return size_;
    }

    // Element type recorded for this view; DataType::NONE for a
    // default-constructed view.
    DataType
    get_element_type() const {
        return element_type_;
    }

    // Read-only, untyped pointer to the start of the viewed element bytes.
    const void*
    data() const {
        return data_;
    }

    // Checks whether the viewed array equals a plan-proto array literal.
    //
    // Returns false when the lengths differ, true when both are empty, and
    // false when a non-empty `arr2` is not flagged same_type(). Otherwise
    // element i is compared with arr2.array(i) using the accessor matching
    // the element type:
    //   BOOL -> bool_val, INT8/INT16/INT32/INT64 -> int64_val,
    //   FLOAT/DOUBLE -> float_val, VARCHAR/STRING/GEOMETRY -> string_val.
    // Any other element type yields false.
    bool
    is_same_array(const proto::plan::Array& arr2) const {
        if (arr2.array_size() != length_) {
            return false;
        }
        // Two empty arrays are equal regardless of the same_type flag; this
        // matches Array::is_same_array, which checks emptiness before
        // consulting same_type().
        if (length_ == 0) {
            return true;
        }
        if (!arr2.same_type()) {
            return false;
        }
        switch (element_type_) {
            case DataType::BOOL: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<bool>(i);
                    if (val != arr2.array(i).bool_val()) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::INT8:
            case DataType::INT16:
            case DataType::INT32: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<int>(i);
                    if (val != arr2.array(i).int64_val()) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::INT64: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<int64_t>(i);
                    if (val != arr2.array(i).int64_val()) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::FLOAT: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<float>(i);
                    if (val != arr2.array(i).float_val()) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::DOUBLE: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<double>(i);
                    if (val != arr2.array(i).float_val()) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::VARCHAR:
            case DataType::STRING:
            case DataType::GEOMETRY: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<std::string>(i);
                    if (val != arr2.array(i).string_val()) {
                        return false;
                    }
                }
                return true;
            }
            default:
                // Non-empty array of an element type that cannot be compared.
                return false;
        }
    }

 private:
    // Start of the viewed element bytes; stored as given by the caller, not
    // copied.
    char* data_{nullptr};
    // Number of elements.
    int length_ = 0;
    // Size of the viewed element data in bytes.
    int size_ = 0;
    // Type shared by all elements; NONE for a default-constructed view.
    DataType element_type_ = DataType::NONE;

    // Per-element start offsets into data_, stored as given by the caller;
    // asserted non-null by the constructors for variable-length element
    // types.
    uint32_t* offsets_ptr_{nullptr};
};

}  // namespace milvus
