/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/sharding_util.h"

#include <optional>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h"
#include "xla/shape.h"
#include "xla/tsl/platform/statusor.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace {
// Substring looked for in a parsed device type to decide whether the device
// names a replicated core.
constexpr char kDeviceSuffixReplicatedCore[] = "REPLICATED_CORE";
// Node attribute holding an encoded xla::OpSharding.
constexpr char kShardingAttribute[] = "_XlaSharding";
// Second sharding node attribute; when present it is checked for equivalence
// against the primary sharding.
constexpr char kShardingAttributeV2[] = "_XlaShardingV2";
// Op type for which `kShardingOpAttribute` is consulted.
constexpr char kXlaShardingOp[] = "XlaSharding";
// Attribute on the 'XlaSharding' op holding an encoded xla::OpSharding.
constexpr char kShardingOpAttribute[] = "sharding";
}  // namespace

namespace {
// Returns an xla::OpMetadata with its `op_type` and `op_name` fields set to
// the given values. All other fields are left at their defaults.
xla::OpMetadata CreateOpMetadata(const std::string& op_type,
                                 const std::string& op_name) {
  xla::OpMetadata result;
  result.set_op_name(op_name);
  result.set_op_type(op_type);
  return result;
}

// Appends a metadata entry built from `op_type` and `op_name` to `sharding`.
//
// For a TUPLE sharding the entry is appended to every element of
// `tuple_shardings` and the top-level sharding itself is left untouched; for
// any other sharding type the entry is appended to `sharding` directly.
void AssignOpMetadataToSharding(xla::OpSharding& sharding,
                                const std::string& op_type,
                                const std::string& op_name) {
  const xla::OpMetadata op_metadata = CreateOpMetadata(op_type, op_name);

  if (sharding.type() != xla::OpSharding::TUPLE) {
    *sharding.add_metadata() = op_metadata;
    return;
  }

  for (xla::OpSharding& element : *sharding.mutable_tuple_shardings()) {
    *element.add_metadata() = op_metadata;
  }
}

// Builds the InvalidArgument status reported when a replicated core id does
// not fall in the range accepted by the caller. The message includes both the
// offending `core` and `num_cores_per_replica`.
absl::Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
  absl::Status out_of_range = errors::InvalidArgument(
      "Invalid replicated core id: ", core,
      "; num_cores_per_replica=", num_cores_per_replica);
  return out_of_range;
}
}  // namespace

// Resolves the sharding for an op from its device name and an optional
// explicit sharding.
//
// Resolution order:
//   1. Empty `device_name`: `explicit_sharding` is returned as is (possibly
//      empty), without any further validation.
//   2. `device_name` that fails DeviceNameUtils::ParseFullName: InvalidArgument,
//      even when an explicit sharding was supplied.
//   3. `explicit_sharding` present: it is returned unchanged.
//   4. Device without a type, without an id, or whose type does not contain
//      "REPLICATED_CORE": an empty optional.
//   5. Otherwise the device id is taken as the core. It must satisfy
//      0 <= core < num_cores_per_replica (InvalidArgument if not), and the
//      result is the sharding produced by
//      xla::sharding_builder::AssignDevice(core), with `metadata` appended
//      when provided.
absl::StatusOr<std::optional<xla::OpSharding>> ParseShardingFromDevice(
    const std::string& device_name, int num_cores_per_replica,
    std::optional<xla::OpSharding> explicit_sharding,
    std::optional<xla::OpMetadata> metadata) {
  if (device_name.empty()) {
    return explicit_sharding;
  }

  DeviceNameUtils::ParsedName parsed;
  const bool parse_ok = DeviceNameUtils::ParseFullName(device_name, &parsed);
  if (!parse_ok) {
    return errors::InvalidArgument("Malformed assigned device '", device_name,
                                   "'");
  }

  // An explicit sharding always wins over one derived from the device.
  if (explicit_sharding.has_value()) {
    return explicit_sharding;
  }

  const bool names_replicated_core =
      parsed.has_type && parsed.has_id &&
      absl::StrContains(parsed.type, kDeviceSuffixReplicatedCore);
  if (!names_replicated_core) {
    return std::optional<xla::OpSharding>();
  }

  const int core_id = parsed.id;
  if (core_id < 0 || core_id >= num_cores_per_replica) {
    return CoreOutOfRangeError(core_id, num_cores_per_replica);
  }

  xla::OpSharding assigned = xla::sharding_builder::AssignDevice(core_id);
  if (metadata.has_value()) {
    *assigned.add_metadata() = *metadata;
  }
  return std::optional<xla::OpSharding>(assigned);
}

// NodeDef overload: reads any explicit sharding from the node's attributes
// via GetShardingFromNodeDef, then defers to the device-name overload using
// `node_def.device()`.
//
// When `add_metadata` is true, metadata built from the node's op and name is
// forwarded so that it can be attached to a device-derived sharding.
// Errors from attribute parsing are propagated.
absl::StatusOr<std::optional<xla::OpSharding>> ParseShardingFromDevice(
    const NodeDef& node_def, int num_cores_per_replica, bool add_metadata) {
  TF_ASSIGN_OR_RETURN(std::optional<xla::OpSharding> explicit_sharding,
                      GetShardingFromNodeDef(node_def, add_metadata));
  return ParseShardingFromDevice(
      node_def.device(), num_cores_per_replica, explicit_sharding,
      add_metadata
          ? std::make_optional(CreateOpMetadata(node_def.op(), node_def.name()))
          : std::nullopt);
}

// Node overload: uses the node's assigned device name, falling back to its
// requested device when no device has been assigned, together with any
// explicit sharding found on the node's NodeDef.
//
// When `add_metadata` is true, metadata built from the node's type string and
// name is forwarded so that it can be attached to a device-derived sharding.
// Errors from attribute parsing are propagated.
absl::StatusOr<std::optional<xla::OpSharding>> ParseShardingFromDevice(
    const Node& node, int num_cores_per_replica, bool add_metadata) {
  const std::string& device_name = node.assigned_device_name().empty()
                                       ? node.requested_device()
                                       : node.assigned_device_name();

  TF_ASSIGN_OR_RETURN(std::optional<xla::OpSharding> explicit_sharding,
                      GetShardingFromNodeDef(node.def(), add_metadata));
  return ParseShardingFromDevice(
      device_name, num_cores_per_replica, explicit_sharding,
      add_metadata
          ? std::make_optional(CreateOpMetadata(node.type_string(), node.name()))
          : std::nullopt);
}

// Returns the sharding that applies to the value carried by `edge`, derived
// from the edge's source node.
//
// If the source node's sharding is a TUPLE, the element selected by
// `edge.src_output()` is returned; otherwise the source sharding (possibly
// empty) is returned unchanged.
//
// Returns InvalidArgument when the edge has no source node, or when the
// source output index is negative or not smaller than the number of tuple
// shardings. Errors from parsing the source node's sharding are propagated.
absl::StatusOr<std::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
    const Edge& edge, int num_cores_per_replica, bool add_metadata) {
  const Node* const source = edge.src();
  if (source == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "Null src for ParseShardingFromEdgeSource edge=", edge.DebugString());
  }

  TF_ASSIGN_OR_RETURN(
      std::optional<xla::OpSharding> src_sharding,
      ParseShardingFromDevice(*source, num_cores_per_replica, add_metadata));

  const bool is_tuple_sharding =
      src_sharding.has_value() &&
      src_sharding->type() == xla::OpSharding::TUPLE;
  if (!is_tuple_sharding) {
    return src_sharding;
  }

  // Pick the tuple element that corresponds to the edge's source output.
  const int output_index = edge.src_output();
  if (output_index < 0 ||
      output_index >= src_sharding->tuple_shardings_size()) {
    return tensorflow::errors::InvalidArgument(
        "Tuple index out of bound: edge=", edge.DebugString(),
        " sharding=", src_sharding->DebugString());
  }
  return std::optional<xla::OpSharding>(
      src_sharding->tuple_shardings(output_index));
}

void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
  std::string device_name = src.assigned_device_name();
  if (device_name.empty()) {
    device_name = src.requested_device();
  }
  dst->set_assigned_device_name(device_name);
  if (const AttrValue* attr = src.attrs().Find(kShardingAttribute)) {
    dst->AddAttr(kShardingAttribute, *attr);
  }
}

namespace {

// Reads the string attribute named `attribute` from `node_def` and decodes it
// into an xla::OpSharding.
//
// Returns an empty optional when the attribute is absent. Returns an error
// when the attribute cannot be read as a string, and InvalidArgument when
// DecodeShardingAttribute reports failure. When `add_metadata` is true, the
// decoded sharding is annotated with metadata built from the node's op and
// name.
absl::StatusOr<std::optional<xla::OpSharding>> GetShardingFromNodeDefInternal(
    const NodeDef& node_def, bool add_metadata, const char* attribute) {
  if (!HasNodeAttr(node_def, attribute)) {
    return std::optional<xla::OpSharding>();
  }

  std::string encoded;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attribute, &encoded));

  xla::OpSharding decoded;
  const bool decode_failed =
      tensorflow::DecodeShardingAttribute(encoded, decoded).failed();
  if (decode_failed) {
    return xla::InvalidArgument(
        "Experimental %s attribute was not a valid encoded xla::OpSharding "
        "proto.",
        attribute);
  }

  if (add_metadata) {
    AssignOpMetadataToSharding(decoded, node_def.op(), node_def.name());
  }
  return std::optional<xla::OpSharding>(decoded);
}

}  // namespace

// Extracts the sharding annotated on `node_def`, if any.
//
// The primary sharding is chosen as follows:
//   * For an 'XlaSharding' op, the op's "sharding" attribute is used. If that
//     attribute is absent, the result of parsing "_XlaSharding" is returned
//     directly (possibly empty) and "_XlaShardingV2" is not consulted.
//   * For any other op, "_XlaSharding" is used; if it is absent an empty
//     optional is returned and "_XlaShardingV2" is not consulted.
//
// If "_XlaShardingV2" is also present, it must be equivalent to the primary
// sharding according to VerifyShardingEquivalent; the V2 sharding is then
// returned. Otherwise InvalidArgument is returned. If V2 is absent the primary
// sharding is returned.
//
// "_XlaSharding" is always parsed first, so a malformed value is reported as
// an error even for an 'XlaSharding' op that has a valid "sharding" attribute.
absl::StatusOr<std::optional<xla::OpSharding>> GetShardingFromNodeDef(
    const NodeDef& node_def, bool add_metadata) {
  TF_ASSIGN_OR_RETURN(std::optional<xla::OpSharding> node_attr_sharding,
                      GetShardingFromNodeDefInternal(node_def, add_metadata,
                                                     kShardingAttribute));

  xla::OpSharding primary;
  if (node_def.op() != kXlaShardingOp) {
    if (!node_attr_sharding.has_value()) {
      return std::optional<xla::OpSharding>();
    }
    primary = *node_attr_sharding;
  } else {
    // kShardingOpAttribute is only defined for the 'XlaSharding' op.
    TF_ASSIGN_OR_RETURN(std::optional<xla::OpSharding> op_attr_sharding,
                        GetShardingFromNodeDefInternal(node_def, add_metadata,
                                                       kShardingOpAttribute));
    if (!op_attr_sharding.has_value()) {
      return node_attr_sharding;
    }
    primary = *op_attr_sharding;
  }

  TF_ASSIGN_OR_RETURN(std::optional<xla::OpSharding> v2_sharding,
                      GetShardingFromNodeDefInternal(node_def, add_metadata,
                                                     kShardingAttributeV2));
  if (!v2_sharding.has_value()) {
    return primary;
  }

  const bool mismatch =
      tensorflow::VerifyShardingEquivalent(primary, *v2_sharding).failed();
  if (mismatch) {
    return absl::InvalidArgumentError(absl::StrCat(
        "XlaSharding attribute was not equivalent to XlaShardingV2 "
        "attribute: ",
        primary.DebugString(), " vs ", v2_sharding->DebugString()));
  }
  return v2_sharding;
}

// Attaches a sharding frontend attribute to `op` based on the sharding
// currently set on `builder`.
//
// If `builder` has no sharding set, nothing is done and OK is returned.
// Otherwise the builder's sharding is converted with
// xla::sdy::convertToSdySharding (closed dimensions, inlined mesh) for the
// given `shape`, and the result is stored on `op` under
// xla::HloSharding::kShardingFrontendAttrName. `is_single_arg` is forwarded to
// the conversion unchanged. Returns the status of setting the attribute.
absl::Status addSdyShardingFrontendAttribute(xla::XlaBuilder* builder,
                                             xla::XlaOp op, xla::Shape shape,
                                             bool is_single_arg) {
  const auto& active_sharding = builder->sharding();
  if (active_sharding.has_value()) {
    return builder->SetInstructionFrontendAttribute(
        op, std::string(xla::HloSharding::kShardingFrontendAttrName),
        xla::sdy::convertToSdySharding(*active_sharding, shape,
                                       /*openDims=*/false,
                                       /*inlineMesh=*/true, is_single_arg));
  }
  return absl::OkStatus();
}

}  // namespace tensorflow
