/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/tsl/distributed_runtime/rpc/grpc_channel.h"

#include <cstdlib>
#include <limits>
#include <map>
#include <string>
#include <unordered_map>

#include "absl/status/status.h"
#include "absl/strings/escaping.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/synchronization/mutex.h"
#include "grpcpp/create_channel.h"
#include "xla/tsl/distributed_runtime/rpc/grpc_channel_common.h"
#include "xla/tsl/lib/gtl/map_util.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/macros.h"
#include "xla/tsl/platform/status.h"
#include "xla/tsl/platform/types.h"
#include "xla/tsl/protobuf/rpc_options.pb.h"
#include "xla/tsl/util/device_name_utils.h"
#include "tsl/platform/numbers.h"
#include "tsl/platform/str_util.h"
#include "tsl/platform/strcat.h"
#include "tsl/platform/thread_annotations.h"

namespace tsl {

namespace {

// Formats the worker name "/job:<job>/replica:<replica>/task:<task>".
std::string MakeAddress(const std::string& job, int replica, int task) {
  std::string address = "/job:";
  address += job;
  address += "/replica:";
  address += std::to_string(replica);
  address += "/task:";
  address += std::to_string(task);
  return address;
}

// Performs a minimal syntactic check that `host_port` has the form
// "<host>:<port>".
//
// Targets that start with "/bns/" are accepted without further inspection.
// Otherwise the string is split at its last ':' (which allows the host to be
// a raw IP, either v4 or v6): the text after the colon must parse as an
// unsigned 32-bit integer and the text before it must not contain '/'.
//
// Returns OkStatus() if the check passes and InvalidArgumentError otherwise.
absl::Status ValidateHostPortPair(const std::string& host_port) {
  if (absl::StartsWith(host_port, "/bns/")) {
    return absl::OkStatus();
  }
  const auto colon_index = host_port.find_last_of(':');
  uint32_t port = 0;
  // A missing ':' must be rejected explicitly: `npos + 1` wraps around to 0,
  // so without this check a bare number such as "2222" would be parsed as the
  // port and accepted even though it contains no host component.
  if (colon_index == std::string::npos ||
      !absl::SimpleAtoi(host_port.substr(colon_index + 1), &port) ||
      host_port.substr(0, colon_index).find('/') != std::string::npos) {
    return absl::InvalidArgumentError(absl::StrCat(
        "Could not interpret \"", host_port, "\" as a host-port pair."));
  }
  return absl::OkStatus();
}

// Builds a ChannelArguments object from the TF_GRPC_DEFAULT_OPTIONS
// environment variable.
//
// The variable holds a comma-separated list of `name=value` entries. A value
// enclosed in double quotes is C-unescaped and stored as a string argument;
// any other value must parse as an `int` and is stored as an integer
// argument. Malformed entries are logged at ERROR level and skipped.
//
// Returns a pointer to a ChannelArguments allocated with `new`.
::grpc::ChannelArguments* CreateDefaultChannelArguments() {
  ::grpc::ChannelArguments* args = new ::grpc::ChannelArguments();
  const char* env = std::getenv("TF_GRPC_DEFAULT_OPTIONS");
  if (env == nullptr) {
    return args;
  }
  for (const auto& grpc_option : absl::StrSplit(env, ',')) {
    std::vector<std::string> name_value = absl::StrSplit(grpc_option, '=');
    if (name_value.size() != 2) {
      LOG(ERROR) << "Invalid GRPC options format: " << grpc_option;
      continue;
    }
    const std::string& name = name_value[0];
    const std::string& raw_value = name_value[1];
    VLOG(3) << "Setting GRPC default for '" << name << "' to '" << raw_value
            << "'";
    if (raw_value.size() >= 2 && raw_value.front() == '"') {
      // Require the matching closing quote; otherwise stripping the last
      // character would silently drop a real character of the value.
      if (raw_value.back() != '"') {
        LOG(ERROR) << "Failed to parse escaped string for " << grpc_option
                   << ": missing closing quote";
        continue;
      }
      const std::string quoted_body =
          raw_value.substr(1, raw_value.size() - 2);
      std::string value;
      std::string error;
      if (!absl::CUnescape(quoted_body, &value, &error)) {
        LOG(ERROR) << "Failed to parse escaped string for " << grpc_option
                   << ": " << error;
      } else {
        args->SetString(name, value);
      }
    } else {
      // Parse directly into `int`, the type SetInt() accepts, so that values
      // outside the `int` range are reported instead of silently truncated.
      int value = 0;
      if (absl::SimpleAtoi(raw_value, &value)) {
        args->SetInt(name, value);
      } else {
        LOG(ERROR) << "Invalid integer value: " << grpc_option;
      }
    }
  }
  return args;
}

// Returns the shared default channel arguments. They are built by
// CreateDefaultChannelArguments() the first time this function is called and
// the same pointer is returned on every subsequent call.
const ::grpc::ChannelArguments* GetDefaultChannelArguments() {
  static const ::grpc::ChannelArguments* default_args =
      CreateDefaultChannelArguments();
  return default_args;
}

}  // namespace

// Returns the channel arguments to use for a new channel: a copy of the
// defaults with the maximum message length and reconnect backoff overridden,
// plus the compression and connection-sharing settings requested by
// `rpc_options` (which may be null).
::grpc::ChannelArguments GetChannelArguments(const RPCOptions* rpc_options) {
  // TODO(mrry): Implement secure channels.
  ::grpc::ChannelArguments args = *GetDefaultChannelArguments();
  args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32_t>::max());
  // NOTE(mrry): Some versions of gRPC use a 20-second minimum backoff
  // on connection failure, which makes our tests time out.
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
  if (rpc_options == nullptr) {
    return args;
  }

  const auto& algorithm = rpc_options->compression_algorithm();
  const bool wants_deflate = (algorithm == "deflate");
  const bool wants_gzip = (algorithm == "gzip");
  if (wants_deflate || wants_gzip) {
    // Both supported algorithms are configured the same way; only the gRPC
    // algorithm constant differs.
    args.SetCompressionAlgorithm(wants_deflate ? GRPC_COMPRESS_DEFLATE
                                               : GRPC_COMPRESS_GZIP);
    args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
                rpc_options->compression_level());
    VLOG(5) << "Setting GRPC compression : algo='" << algorithm
            << "' level=" << rpc_options->compression_level();
  } else if (!algorithm.empty()) {
    LOG(ERROR) << "Invalid compression algorithm: " << algorithm;
  }

  if (rpc_options->disable_session_connection_sharing()) {
    VLOG(5) << "Disabling TCP connection sharing";
    args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true);
  }
  return args;
}

// Creates an insecure gRPC channel to "dns:///<target>" configured from
// `rpc_options` and stores it in `*channel_pointer`.
//
// Returns the validation error, without creating a channel, if `target` does
// not pass ValidateHostPortPair(); returns OkStatus() otherwise.
absl::Status NewHostPortGrpcChannel(const std::string& target,
                                    const RPCOptions* rpc_options,
                                    SharedGrpcChannelPtr* channel_pointer) {
  // Reject targets that are obviously malformed before touching gRPC.
  TF_RETURN_IF_ERROR(ValidateHostPortPair(target));

  ::grpc::ChannelArguments channel_args = GetChannelArguments(rpc_options);
  const std::string dns_target = "dns:///" + target;
  *channel_pointer = ::grpc::CreateCustomChannel(
      dns_target, ::grpc::InsecureChannelCredentials(), channel_args);
  return absl::OkStatus();
}

// Adapts a status-returning channel factory into a ChannelCreationFunction.
//
// The returned callable invokes `new_channel_func_ptr` with the target and a
// null RPCOptions pointer, and yields the created channel on success or
// nullptr if the factory reports a non-OK status.
ChannelCreationFunction ConvertToChannelCreationFunction(
    const std::function<absl::Status(std::string, const RPCOptions*,
                                     SharedGrpcChannelPtr*)>&
        new_channel_func_ptr) {
  return [new_channel_func_ptr](
             const std::string& target) -> SharedGrpcChannelPtr {
    SharedGrpcChannelPtr created;
    const absl::Status status =
        new_channel_func_ptr(target, /*rpc_options=*/nullptr, &created);
    if (!status.ok()) {
      return nullptr;
    }
    return created;
  };
}

// Registers job `job_id` with the given task-index -> host:port mapping.
//
// Returns InvalidArgumentError if `job_id` has already been added, or the
// validation error if any entry of `host_ports` fails
// ValidateHostPortPair(). On any error the spec is left unmodified.
absl::Status GrpcChannelSpec::AddHostPortsJob(
    const std::string& job_id, const std::map<int, std::string>& host_ports) {
  if (job_ids_.count(job_id) > 0) {
    return absl::InvalidArgumentError(
        absl::StrCat("Duplicate job ID in cluster specification: ", job_id));
  }
  for (const auto& id_host_port : host_ports) {
    TF_RETURN_IF_ERROR(ValidateHostPortPair(id_host_port.second));
  }
  // Record the job id only after every address has been validated. Recording
  // it up front would leave the id registered when validation fails, so a
  // corrected retry would be rejected as a duplicate.
  job_ids_.insert(job_id);
  host_ports_jobs_.emplace_back(job_id, host_ports);
  return absl::OkStatus();
}

namespace {

// GrpcChannelCache that caches results to FindWorkerChannel() calls.
// This is GenericCachingChannelCache instantiated for GrpcChannelCache; the
// cache implementations in this file derive from it.
using CachingGrpcChannelCache = GenericCachingChannelCache<GrpcChannelCache>;

// A ChannelCache that is the union of multiple ChannelCaches.
// Takes ownership of the caches passed to the constructor.
class MultiGrpcChannelCache : public CachingGrpcChannelCache {
 public:
  explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches,
                                 int num_channels_per_target)
      : CachingGrpcChannelCache(num_channels_per_target), caches_(caches) {}

  ~MultiGrpcChannelCache() override {
    for (GrpcChannelCache* cache : caches_) {
      delete cache;
    }
  }

  void ListWorkers(std::vector<std::string>* workers) override {
    for (GrpcChannelCache* cache : caches_) {
      cache->ListWorkers(workers);
    }
  }

  void ListWorkersInJob(const std::string& job_name,
                        std::vector<std::string>* workers) override {
    for (GrpcChannelCache* cache : caches_) {
      cache->ListWorkersInJob(job_name, workers);
    }
  }

  std::string TranslateTask(const std::string& target) override {
    absl::MutexLock l(mu_);  // could use reader lock
    GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
    if (cache == nullptr) {
      for (GrpcChannelCache* c : caches_) {
        std::string r = c->TranslateTask(target);
        if (!r.empty()) {
          target_caches_.insert({target, c});
          cache = c;
          break;
        }
      }
    }
    CHECK(cache) << "Could not find GrpcChannelCache holding channel for "
                 << target;
    return cache->TranslateTask(target);
  }

 protected:
  SharedGrpcChannelPtr FindChannelOnce(const std::string& target) override {
    for (GrpcChannelCache* cache : caches_) {
      SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
      if (ch) {
        absl::MutexLock l(mu_);
        target_caches_.insert({target, cache});
        return ch;
      }
    }
    return nullptr;
  }

 private:
  // List of channels used by this MultiGrpcChannelCache.
  const std::vector<GrpcChannelCache*> caches_;

  absl::Mutex mu_;
  // Cache of channels keyed by the target they are handling.
  // The same GrpcChannelCache can appear multiple times in the cache.
  std::unordered_map<std::string, GrpcChannelCache*> target_caches_
      TF_GUARDED_BY(mu_);
};

// A channel cache for a single job. The job's tasks are described by a map
// from task index to a comma-separated list of host:port addresses, where the
// position of an address in the list is its replica index.
class SparseGrpcChannelCache : public CachingGrpcChannelCache {
 public:
  // `channel_func` is invoked with a translated host:port address to create
  // a channel. `num_channels_per_target` is passed to the caching base class.
  SparseGrpcChannelCache(const std::string& job_id,
                         const std::map<int, std::string>& host_ports,
                         ChannelCreationFunction channel_func,
                         int num_channels_per_target)
      : CachingGrpcChannelCache(num_channels_per_target),
        job_id_(job_id),
        host_ports_(host_ports),
        channel_func_(std::move(channel_func)) {
    VLOG(2) << "Initialize GrpcChannelCache for job " << ToString();
  }
  ~SparseGrpcChannelCache() override = default;

  // Appends one "/job:<job>/replica:<r>/task:<t>" name to `workers` for
  // every replica of every task of this job.
  void ListWorkers(std::vector<std::string>* workers) override {
    workers->reserve(workers->size() + host_ports_.size());
    for (const auto& id_host_port : host_ports_) {
      const std::vector<std::string> replicas =
          absl::StrSplit(id_host_port.second, ',', absl::SkipEmpty());
      // Convert once so the loop does not compare signed against unsigned.
      const int num_replicas = static_cast<int>(replicas.size());
      for (int replica = 0; replica < num_replicas; ++replica) {
        workers->emplace_back(
            MakeAddress(job_id_, replica, id_host_port.first));
      }
    }
  }

  // Appends this job's workers to `workers` if `job_name` names this job;
  // does nothing otherwise.
  void ListWorkersInJob(const std::string& job_name,
                        std::vector<std::string>* workers) override {
    if (job_name == job_id_) {
      ListWorkers(workers);
    }
  }

  // Returns the host:port address for the task named by `target`.
  //
  // Returns "" if `target` cannot be parsed, names a different job, names a
  // task that is not defined for this job, or names a task whose address
  // list is empty. If the requested replica is out of range, a warning is
  // logged and the address of replica 0 is returned.
  std::string TranslateTask(const std::string& target) override {
    DeviceNameUtils::ParsedName parsed;
    if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
      LOG(WARNING) << "Invalid target: " << target;
      return "";
    }

    if (!parsed.has_job || parsed.job != job_id_) {
      return "";
    }

    // A target without a task component is looked up as task -1.
    int32_t task = parsed.has_task ? parsed.task : -1;
    auto iter = host_ports_.find(task);
    if (iter == host_ports_.end()) {
      LOG(WARNING) << "Task " << task << " was not defined in sparse job "
                   << job_id_ << ": " << target;
      return "";
    }

    const std::vector<std::string> host_ports =
        absl::StrSplit(iter->second, ',', absl::SkipEmpty());
    // Guard against an empty address list: without this check the replica-0
    // fallback below would index into an empty vector.
    if (host_ports.empty()) {
      LOG(WARNING) << "No host-port entries defined for task " << task
                   << " in sparse job " << job_id_ << ": " << target;
      return "";
    }
    // Compare in the unsigned domain only after ruling out negative replicas,
    // so a negative value takes the fallback path by design rather than via
    // an implicit signed-to-unsigned conversion.
    if (parsed.replica >= 0 &&
        static_cast<std::vector<std::string>::size_type>(parsed.replica) <
            host_ports.size()) {
      return host_ports[parsed.replica];
    }
    LOG(WARNING) << "Requested out-of-range replica, defaulting to 0: "
                 << target;
    return host_ports[0];
  }

 protected:
  // Translates `target` to a host:port address and creates a channel to it
  // with the channel creation function. Returns nullptr if `target` cannot
  // be translated.
  SharedGrpcChannelPtr FindChannelOnce(const std::string& target) override {
    const std::string host_port = TranslateTask(target);
    if (host_port.empty()) {
      return nullptr;
    }
    auto chan_ptr = channel_func_(host_port);
    VLOG(5) << "Channel created for: job: " << job_id_
            << " host_port: " << host_port << " target : " << target
            << " Ptr: " << chan_ptr.get();
    return chan_ptr;
  }

 private:
  // Returns "<job> -> {<task> -> <addresses>, ...}" for logging.
  std::string ToString() const {
    std::vector<std::string> task_strings;
    task_strings.reserve(host_ports_.size());
    for (const auto& id_host_port : host_ports_) {
      task_strings.emplace_back(
          absl::StrCat(id_host_port.first, " -> ", id_host_port.second));
    }
    return absl::StrCat(job_id_, " -> {", absl::StrJoin(task_strings, ", "),
                        "}");
  }

  // Name of the job served by this cache.
  const std::string job_id_;
  // Task index -> comma-separated host:port addresses (one per replica).
  const std::map<int, std::string> host_ports_;
  // Creates a channel for a host:port address.
  const ChannelCreationFunction channel_func_;
  SparseGrpcChannelCache(const SparseGrpcChannelCache&) = delete;
  void operator=(const SparseGrpcChannelCache&) = delete;
};

}  // namespace

// Creates a channel cache covering every job in `spec`.
//
// One SparseGrpcChannelCache is created per job. With exactly one job that
// cache is returned directly; with several, they are wrapped in a
// MultiGrpcChannelCache. Logs an error and returns nullptr if `spec`
// contains no jobs. The returned object is allocated with `new`.
GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec,
                                      ChannelCreationFunction channel_func,
                                      const RPCOptions& options) {
  const auto& jobs = spec.host_ports_jobs();
  if (jobs.empty()) {
    LOG(ERROR) << "Empty channel spec.";
    return nullptr;
  }

  std::vector<GrpcChannelCache*> per_job_caches;
  per_job_caches.reserve(jobs.size());
  for (const auto& job : jobs) {
    VLOG(2) << "Creating Grpc Channel Cache for: " << job.job_id;
    GrpcChannelCache* job_cache =
        new SparseGrpcChannelCache(job.job_id, job.host_ports, channel_func,
                                   options.num_channels_per_target());
    per_job_caches.push_back(job_cache);
  }

  // A single job needs no multiplexing layer.
  if (per_job_caches.size() == 1) {
    return per_job_caches[0];
  }
  return new MultiGrpcChannelCache(per_job_caches,
                                   options.num_channels_per_target());
}

}  // namespace tsl
