// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <google/protobuf/map.h>
#include <google/protobuf/repeated_field.h>
#include <google/protobuf/timestamp.pb.h>
#include <google/protobuf/util/message_differencer.h>
#include <grpcpp/grpcpp.h>

#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "ray/common/ray_config.h"
#include "ray/common/status.h"
#include "ray/util/logging.h"
#include "ray/util/type_traits.h"

namespace ray {

/// Wrap a protobuf message.
///
/// The message is held through a `std::shared_ptr`, so copies of a wrapper share the
/// same underlying message (a mutation through one copy is visible through the others).
// TODO(#55921): Remove MessageWrapper class and clean up LeaseSpec/TaskSpec classes
template <class Message>
class MessageWrapper {
 public:
  /// Construct an empty message wrapper. This should not be used directly.
  MessageWrapper() : message_(std::make_shared<Message>()) {}

  /// Construct from a protobuf message object.
  /// The message is taken by value and then moved into shared storage: passing an
  /// rvalue avoids a copy, passing an lvalue copies it once.
  ///
  /// \param message The protobuf message.
  explicit MessageWrapper(Message message)
      : message_(std::make_shared<Message>(std::move(message))) {}

  /// Construct from a protobuf message shared_ptr.
  /// The wrapper shares ownership with the caller; the message is not copied.
  ///
  /// \param message The protobuf message.
  explicit MessageWrapper(std::shared_ptr<Message> message)
      : message_(std::move(message)) {}

  /// Construct from protobuf-serialized binary.
  /// Fails a RAY_CHECK if `serialized_binary` cannot be parsed as a `Message`.
  ///
  /// \param serialized_binary Protobuf-serialized binary.
  explicit MessageWrapper(const std::string &serialized_binary)
      : message_(std::make_shared<Message>()) {
    RAY_CHECK(message_->ParseFromString(serialized_binary));
  }

  /// Get const reference of the protobuf message.
  const Message &GetMessage() const { return *message_; }

  /// Get reference of the protobuf message.
  Message &GetMutableMessage() { return *message_; }

  /// Serialize the message to a string.
  /// Returned as a non-const value so callers can move from the result.
  std::string Serialize() const { return message_->SerializeAsString(); }

  /// Two wrappers compare equal when their messages are equivalent according to
  /// google::protobuf::util::MessageDifferencer::Equivalent.
  bool operator==(const MessageWrapper<Message> &rhs) const {
    return google::protobuf::util::MessageDifferencer::Equivalent(GetMessage(),
                                                                  rhs.GetMessage());
  }

 protected:
  /// The wrapped message.
  std::shared_ptr<Message> message_;
};

/// Helper function that converts a ray status to gRPC status.
///
/// - An OK status becomes grpc::Status::OK.
/// - An Unauthenticated status becomes UNAUTHENTICATED, carrying the Ray message.
/// - Every other error becomes ABORTED. The Ray status code string is the gRPC error
///   message and the Ray message is the error details, which lets
///   GrpcStatusToRayStatus rebuild the original status on the receiving side.
inline grpc::Status RayStatusToGrpcStatus(const Status &ray_status) {
  if (!ray_status.ok()) {
    if (ray_status.IsUnauthenticated()) {
      // Authentication failures keep their dedicated gRPC code.
      return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, ray_status.message());
    }
    // The gRPC library itself never generates ABORTED (unlike UNKNOWN). Using it lets
    // the receiver reliably distinguish a Ray-generated error from a transport error.
    return grpc::Status(
        grpc::StatusCode::ABORTED, ray_status.CodeAsString(), ray_status.message());
  }
  return grpc::Status::OK;
}

/// Builds the message attached to a Ray status derived from a failed gRPC call.
/// The format is "RPC error: <error_message>". When the gRPC status carries non-empty
/// details, ", details: <error_details>" is appended.
inline std::string GrpcStatusToRayStatusMessage(const grpc::Status &s) {
  std::string message;
  absl::StrAppend(&message, "RPC error: ", s.error_message());
  const bool has_details = !s.error_details().empty();
  if (has_details) {
    absl::StrAppend(&message, ", details: ", s.error_details());
  }
  return message;
}

/// Helper function that converts a gRPC status to ray status.
///
/// Mapping:
/// - OK                -> Status::OK().
/// - DEADLINE_EXCEEDED -> TimedOut (the gRPC request timed out).
/// - UNAUTHENTICATED   -> Unauthenticated (authentication failed, e.g. wrong cluster ID).
/// - ABORTED           -> the Ray status encoded by RayStatusToGrpcStatus. The code is
///                        parsed from the error message; the message comes from the
///                        error details.
/// - anything else     -> RpcError carrying the original gRPC code.
/// The TimedOut, Unauthenticated and RpcError cases build their message with
/// GrpcStatusToRayStatusMessage.
inline Status GrpcStatusToRayStatus(const grpc::Status &s) {
  if (s.ok()) {
    return Status::OK();
  }
  switch (s.error_code()) {
    case grpc::StatusCode::DEADLINE_EXCEEDED:
      return {StatusCode::TimedOut, GrpcStatusToRayStatusMessage(s)};
    case grpc::StatusCode::UNAUTHENTICATED:
      return Status::Unauthenticated(GrpcStatusToRayStatusMessage(s));
    case grpc::StatusCode::ABORTED:
      // Produced by Ray code; see RayStatusToGrpcStatus for the encoding.
      return {Status::StringToCode(s.error_message()), s.error_details()};
    default:
      return Status::RpcError(GrpcStatusToRayStatusMessage(s), s.error_code());
  }
}

/// Statuses that are retried infinitely by the GcsClient.
/// Now we only retry UNAVAILABLE and UNKNOWN statuses because that indicates the server
/// may be down.
inline bool IsGrpcRetryableStatus(Status status) {
  return status.IsRpcError() && (status.rpc_code() == grpc::StatusCode::UNAVAILABLE ||
                                 status.rpc_code() == grpc::StatusCode::UNKNOWN);
}

/// Converts a Protobuf `RepeatedPtrField` to a vector.
///
/// `pb_repeated` is received by value, so this function owns its elements and moves
/// each of them into the result. Passing an rvalue avoids copying the elements; an
/// lvalue argument is copied into the parameter first.
template <class T>
inline std::vector<T> VectorFromProtobuf(
    ::google::protobuf::RepeatedPtrField<T> pb_repeated) {
  std::vector<T> result;
  result.reserve(static_cast<size_t>(pb_repeated.size()));
  for (auto &element : pb_repeated) {
    result.push_back(std::move(element));
  }
  return result;
}

/// Converts a Protobuf `RepeatedField` to a vector.
///
/// `pb_repeated` is received by value and each element is moved into the result,
/// preserving the original order.
template <class T>
inline std::vector<T> VectorFromProtobuf(
    ::google::protobuf::RepeatedField<T> pb_repeated) {
  std::vector<T> result;
  result.reserve(static_cast<size_t>(pb_repeated.size()));
  for (auto &element : pb_repeated) {
    result.push_back(std::move(element));
  }
  return result;
}

/// Converts a Protobuf `RepeatedPtrField` of binary ID strings to a vector of IDs.
///
/// Each element is converted with `ID::FromBinary`. The elements are read in place,
/// with no intermediate copy of the whole repeated field, and the output vector is
/// sized once up front.
///
/// \param pb_repeated Repeated field whose elements are binary-encoded IDs.
/// \return The IDs, in the same order as `pb_repeated`.
template <class ID>
inline std::vector<ID> IdVectorFromProtobuf(
    const ::google::protobuf::RepeatedPtrField<::std::string> &pb_repeated) {
  std::vector<ID> ids;
  ids.reserve(static_cast<size_t>(pb_repeated.size()));
  for (const std::string &binary : pb_repeated) {
    ids.push_back(ID::FromBinary(binary));
  }
  return ids;
}

/// Converts a Protobuf map to a cpp map.
/// Every key/value pair of `pb_map` is copied into a newly built `absl::flat_hash_map`.
template <class K, class V>
inline absl::flat_hash_map<K, V> MapFromProtobuf(
    const ::google::protobuf::Map<K, V> &pb_map) {
  absl::flat_hash_map<K, V> result;
  result.insert(pb_map.begin(), pb_map.end());
  return result;
}

/// Debug string for a google protobuf map, formatted as `{k1:v1,k2:v2}`.
///
/// Entries appear in the map's iteration order. Keys and values are written with
/// `operator<<`, so both K and V must be streamable.
template <class K, class V>
inline std::string DebugString(const ::google::protobuf::Map<K, V> &pb_map) {
  std::ostringstream ss;
  ss << "{";
  bool first = true;
  for (const auto &pair : pb_map) {
    // Emit the separator before every entry except the first one. `first` must be
    // cleared unconditionally; otherwise no separator is ever written.
    if (!first) {
      ss << ",";
    }
    first = false;
    ss << pair.first << ":" << pair.second;
  }
  ss << "}";
  return ss.str();
}

/// Check whether 2 google::protobuf::Map are equal: both hold the same set of keys and,
/// for every key, equal values.
///
/// The value type V must either support operator== (used directly) or be a protobuf
/// message. Protobuf message values are compared with
/// google::protobuf::util::MessageDifferencer::Equivalent. Any other V is rejected at
/// compile time by the static_assert below.
template <class K, class V>
bool MapEqual(const ::google::protobuf::Map<K, V> &lhs,
              const ::google::protobuf::Map<K, V> &rhs) {
  static_assert(
      has_equal_operator<V>::value ||
          std::is_base_of<google::protobuf::Message, V>::value,
      "Invalid value type for the map. The value type of the map must either be a "
      "simple type that supports operator== or a protobuf message that can be compared "
      "using google::protobuf::util::MessageDifferencer::Equivalent");

  // Equal sizes plus "every lhs key exists in rhs with an equal value" implies equality.
  if (lhs.size() != rhs.size()) {
    return false;
  }

  for (const auto &pair : lhs) {
    const auto it = rhs.find(pair.first);
    if (it == rhs.end()) {
      return false;
    }
    if constexpr (has_equal_operator<V>::value) {
      // Only operator== is required by the contract, so do not rely on operator!=.
      if (!(it->second == pair.second)) {
        return false;
      }
    } else {
      // The static_assert above guarantees V is a protobuf message on this path.
      if (!google::protobuf::util::MessageDifferencer::Equivalent(it->second,
                                                                  pair.second)) {
        return false;
      }
    }
  }

  return true;
}

/// Builds the default gRPC channel arguments from RayConfig.
///
/// Keepalive is configured only when `grpc_client_keepalive_time_ms` is positive. In
/// that case the keepalive time and timeout are set, and
/// GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA is set to 0. The client idle timeout is always
/// set from `grpc_client_idle_timeout_ms`.
inline grpc::ChannelArguments CreateDefaultChannelArguments() {
  grpc::ChannelArguments channel_args;
  const auto keepalive_time_ms = ::RayConfig::instance().grpc_client_keepalive_time_ms();
  if (keepalive_time_ms > 0) {
    channel_args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, keepalive_time_ms);
    channel_args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS,
                        ::RayConfig::instance().grpc_client_keepalive_timeout_ms());
    channel_args.SetInt(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0);
  }
  channel_args.SetInt(GRPC_ARG_CLIENT_IDLE_TIMEOUT_MS,
                      ::RayConfig::instance().grpc_client_idle_timeout_ms());
  return channel_args;
}

// Convert an epoch time in nanoseconds to a protobuf timestamp.
//
// google.protobuf.Timestamp requires `nanos` to lie in [0, 999999999], even for
// instants before the epoch: negative seconds with a fraction must still carry a
// non-negative nanos value that counts forward in time. Plain `/` and `%` truncate
// toward zero, so negative inputs are normalized by borrowing one second.
// Ref: https://protobuf.dev/reference/protobuf/google.protobuf/#timestamp
inline google::protobuf::Timestamp AbslTimeNanosToProtoTimestamp(int64_t nanos) {
  constexpr int64_t kNanosPerSecond = 1000000000;
  int64_t seconds = nanos / kNanosPerSecond;
  int64_t fraction = nanos % kNanosPerSecond;
  if (fraction < 0) {
    // e.g. -1ns -> seconds = -1, nanos = 999999999.
    seconds -= 1;
    fraction += kNanosPerSecond;
  }

  google::protobuf::Timestamp timestamp;
  timestamp.set_seconds(seconds);
  timestamp.set_nanos(static_cast<int32_t>(fraction));
  return timestamp;
}

// Convert a protobuf timestamp to an epoch time in nanoseconds.
//
// The valid Timestamp range (years 0001-9999) is wider than what int64 nanoseconds can
// represent (roughly years 1677-2262). `seconds * 1e9` would then be signed overflow,
// which is undefined behavior. Out-of-range values are instead saturated to the int64
// limits. In-range values are converted exactly.
// Ref: https://protobuf.dev/reference/protobuf/google.protobuf/#timestamp
inline int64_t ProtoTimestampToAbslTimeNanos(
    const google::protobuf::Timestamp &timestamp) {
  constexpr int64_t kNanosPerSecond = 1000000000;
  constexpr int64_t kMax = std::numeric_limits<int64_t>::max();
  constexpr int64_t kMin = std::numeric_limits<int64_t>::min();

  const int64_t seconds = timestamp.seconds();
  const int64_t nanos = timestamp.nanos();

  // Guard the multiplication.
  if (seconds > kMax / kNanosPerSecond) {
    return kMax;
  }
  if (seconds < kMin / kNanosPerSecond) {
    return kMin;
  }
  const int64_t whole_seconds_in_nanos = seconds * kNanosPerSecond;

  // Guard the addition.
  if (nanos > 0 && whole_seconds_in_nanos > kMax - nanos) {
    return kMax;
  }
  if (nanos < 0 && whole_seconds_in_nanos < kMin - nanos) {
    return kMin;
  }
  return whole_seconds_in_nanos + nanos;
}

}  // namespace ray
