/* Copyright 2017 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/hash/hash.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "xla/array.h"
#include "xla/array3d.h"
#include "xla/array4d.h"
#include "xla/hlo/ir/mesh_and_axis.h"
#include "xla/hlo/ir/named_sharding.h"
#include "xla/hlo/ir/tile_assignment.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/hlo/testlib/test.h"
#include "xla/hlo/testlib/test_helpers.h"
#include "xla/shape.h"
#include "xla/shape_tree.h"
#include "xla/shape_util.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/util/proto/parse_text_proto.h"
#include "xla/tsl/util/proto/proto_matchers.h"
#include "xla/xla_data.pb.h"

namespace xla {
namespace {

using ::tsl::proto_testing::EqualsProto;
using ::tsl::proto_testing::ParseTextProtoOrDie;

// Builds an `Array<int64_t>` with the given `dimensions` and fills it, in the
// iteration order of `Array::begin()`, with the values in `contents`.
//
// `contents` is expected to hold exactly as many values as the array has
// elements. A mismatch is reported as a (non-fatal) test failure, and at most
// `num_elements()` values are copied so that an over-long `contents` cannot
// write past the end of the array.
Array<int64_t> MakeArray(absl::Span<const int64_t> dimensions,
                         absl::Span<const int64_t> contents) {
  Array<int64_t> a(dimensions);
  EXPECT_EQ(static_cast<int64_t>(contents.size()), a.num_elements());
  // `subspan` clamps its length to the span's size, so this also never reads
  // past the end of `contents`.
  absl::c_copy(contents.subspan(0, static_cast<size_t>(a.num_elements())),
               a.begin());
  return a;
}

// Returns an `OpMetadata` whose only populated field is `op_name`.
OpMetadata GetMetadata(const std::string& op_name) {
  OpMetadata result;
  result.set_op_name(op_name);
  return result;
}

std::vector<OpMetadata> SingleMetadata() { return {GetMetadata("a")}; }

std::vector<OpMetadata> ListMetadata() {
  return {GetMetadata("b"), GetMetadata("c")};
}

// Fixture shared by the non-parameterized HloSharding tests.
struct HloShardingTest : HloHardwareIndependentTestBase {};

// Fixture parameterized on a bool that the tests forward as the
// `use_named_sharding` argument of the HloSharding factory functions.
// TODO(b/456418464): Parameterize `HloShardingTest` itself after supporting
// NamedSharding in all methods.
struct HloShardingRepresentationTest
    : HloShardingTest,
      ::testing::WithParamInterface<bool> {};

TEST_P(HloShardingRepresentationTest, Replicate) {
  const bool use_named_sharding = GetParam();
  const HloSharding sharding = HloSharding::Replicate({}, use_named_sharding);

  // A replicated sharding is tile-maximal, reports using arbitrary device ids,
  // and has no single owning device.
  EXPECT_EQ(sharding.UseNamedShardingLeaf(), use_named_sharding);
  EXPECT_TRUE(sharding.IsReplicated());
  EXPECT_TRUE(sharding.IsTileMaximal());
  EXPECT_TRUE(sharding.UsesDevice(0));
  EXPECT_TRUE(sharding.UsesDevice(65535));
  EXPECT_FALSE(sharding.HasUniqueDevice());

  // Equal within one representation, unequal across the two representations.
  EXPECT_EQ(HloSharding::Replicate({}, use_named_sharding), sharding);
  EXPECT_NE(HloSharding::Replicate(),
            HloSharding::Replicate({}, /*use_named_sharding=*/true));

  EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}),
                                 /*num_devices=*/2));
}

TEST_P(HloShardingRepresentationTest, DevicePlacement) {
  const bool use_named_sharding = GetParam();
  const Shape shape = ShapeUtil::MakeShape(U32, {4});
  const HloSharding sharding =
      HloSharding::AssignDevice(5, {}, use_named_sharding);

  // A maximal sharding pinned to device 5 uses device 5 but not device 0.
  EXPECT_EQ(sharding.UseNamedShardingLeaf(), use_named_sharding);
  EXPECT_FALSE(sharding.IsReplicated());
  EXPECT_TRUE(sharding.IsTileMaximal());
  EXPECT_FALSE(sharding.UsesDevice(0));
  EXPECT_TRUE(sharding.UsesDevice(5));
  EXPECT_EQ(5, sharding.GetUniqueDevice());

  EXPECT_NE(HloSharding::Replicate({}, use_named_sharding), sharding);
  EXPECT_NE(HloSharding::AssignDevice(5),
            HloSharding::AssignDevice(5, {}, /*use_named_sharding=*/true));

  // Device id 5 validates with 6 devices but not with only 5.
  EXPECT_IS_OK(sharding.Validate(shape, /*num_devices=*/6));
  EXPECT_IS_NOT_OK(sharding.Validate(shape, /*num_devices=*/5));

  // For a non-tuple shape the tree is a single leaf holding the sharding.
  const ShapeTree<HloSharding> shape_tree = sharding.GetAsShapeTree(shape);
  EXPECT_EQ(shape_tree.element({}), sharding);
  EXPECT_TRUE(shape_tree.IsLeaf({}));
}

// Converting an OpSharding proto to HloSharding and back must reproduce the
// original proto, including per-element metadata inside a tuple.
TEST_F(HloShardingTest, ProtoRoundTrip) {
  auto proto = ParseTextProtoOrDie<OpSharding>(R"pb(
    type: TUPLE
    tuple_shardings {
      type: OTHER
      tile_assignment_devices: 0
      tile_assignment_devices: 1
      tile_assignment_dimensions: 1
      tile_assignment_dimensions: 2
      metadata { op_name: "a" }
      metadata { op_name: "b" }
    }
    tuple_shardings {
      type: REPLICATED
      metadata { op_name: "c" }
    }
    tuple_shardings { type: MANUAL }
  )pb");
  // TF_ASSERT_OK_AND_ASSIGN instead of `.value()`: if FromProto fails, the
  // status is reported as a test failure rather than terminating the test
  // binary (`StatusOr::value()` on an error throws or aborts).
  TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, HloSharding::FromProto(proto));
  EXPECT_THAT(sharding.ToProto(), EqualsProto(proto));
}

// Same round trip as ProtoRoundTrip, but the tiled element is described by an
// iota reshape/transpose instead of an explicit device list.
TEST_F(HloShardingTest, IotaProtoRoundTrip) {
  auto proto = ParseTextProtoOrDie<OpSharding>(R"pb(
    type: TUPLE
    tuple_shardings {
      type: OTHER
      tile_assignment_dimensions: 6
      tile_assignment_dimensions: 1
      iota_reshape_dims: 3
      iota_reshape_dims: 2
      iota_transpose_perm: 1
      iota_transpose_perm: 0
      metadata { op_name: "a" }
      metadata { op_name: "b" }
    }
    tuple_shardings {
      type: REPLICATED
      metadata { op_name: "c" }
    }
    tuple_shardings { type: MANUAL }
  )pb");
  // Report a FromProto error as a test failure instead of terminating the
  // test binary via `StatusOr::value()`.
  TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, HloSharding::FromProto(proto));
  EXPECT_THAT(sharding.ToProto(), EqualsProto(proto));
}

TEST_F(HloShardingTest, Tile) {
  {
    // Test should fail because of a duplicate tile assignment.
    HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 0, 2, 3}));
    EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {4, 6}),
                                       /*num_devices=*/4));
  }

  {
    // Test should fail because more devices are used than `num_devices`.
    HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3}));
    EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}),
                                       /*num_devices=*/2));
  }

  {
    // Test should fail because not all `num_devices` devices are present in
    // the tile assignment.
    HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3}));
    EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}),
                                       /*num_devices=*/5));
  }

  {
    // Test should pass. The 2x2 assignment {{0, 3}, {2, 1}} splits the 4x5
    // shape into row blocks [0, 2) / [2, 4) and (unevenly) column blocks
    // [0, 3) / [3, 5), as the expectations below spell out.
    Shape shape = ShapeUtil::MakeShape(U32, {4, 5});
    HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}));
    EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {3, 5}),
                                   /*num_devices=*/4));
    // Also validate against `shape`, the shape used by the offset/limit
    // checks below.
    EXPECT_IS_OK(sharding.Validate(shape, /*num_devices=*/4));

    EXPECT_EQ(sharding.TileOffsetForDevice(shape, 0),
              (std::vector<int64_t>{0, 0}));
    EXPECT_EQ(sharding.TileOffsetForDevice(shape, 3),
              (std::vector<int64_t>{0, 3}));
    EXPECT_EQ(sharding.TileOffsetForDevice(shape, 2),
              (std::vector<int64_t>{2, 0}));
    EXPECT_EQ(sharding.TileOffsetForDevice(shape, 1),
              (std::vector<int64_t>{2, 3}));

    EXPECT_EQ(sharding.TileLimitForDevice(shape, 0),
              (std::vector<int64_t>{2, 3}));
    EXPECT_EQ(sharding.TileLimitForDevice(shape, 3),
              (std::vector<int64_t>{2, 5}));
    EXPECT_EQ(sharding.TileLimitForDevice(shape, 2),
              (std::vector<int64_t>{4, 3}));
    EXPECT_EQ(sharding.TileLimitForDevice(shape, 1),
              (std::vector<int64_t>{4, 5}));

    EXPECT_FALSE(sharding.HasUniqueDevice());

    // {device_index, tile_offset, tile_limit}.
    std::vector<std::tuple<int, std::vector<int64_t>, std::vector<int64_t>>>
        tiles;
    TF_ASSERT_OK(sharding.EachTile(
        shape.dimensions(),
        [&tiles](int device_index, absl::Span<const int64_t> tile_offset,
                 absl::Span<const int64_t> tile_limit) {
          std::vector<int64_t> offset(tile_offset.begin(), tile_offset.end());
          std::vector<int64_t> limit(tile_limit.begin(), tile_limit.end());
          tiles.emplace_back(device_index, std::move(offset), std::move(limit));
        }));
    EXPECT_THAT(tiles, ::testing::UnorderedElementsAre(
                           std::make_tuple(0, std::vector<int64_t>{0, 0},
                                           std::vector<int64_t>{2, 3}),
                           std::make_tuple(1, std::vector<int64_t>{2, 3},
                                           std::vector<int64_t>{4, 5}),
                           std::make_tuple(2, std::vector<int64_t>{2, 0},
                                           std::vector<int64_t>{4, 3}),
                           std::make_tuple(3, std::vector<int64_t>{0, 3},
                                           std::vector<int64_t>{2, 5})));
  }
}

TEST_F(HloShardingTest, EachTile) {
  // Runs EachTile over `shape`. For every device it visits, the reported
  // offset/limit must agree with TileOffsetForDevice/TileLimitForDevice.
  auto validate = [](const Shape& shape,
                     const HloSharding& sharding) -> absl::Status {
    auto check_tile = [&shape, &sharding](
                          int device_index,
                          absl::Span<const int64_t> tile_offset,
                          absl::Span<const int64_t> tile_limit) {
      EXPECT_EQ(tile_offset,
                sharding.TileOffsetForDevice(shape, device_index));
      EXPECT_EQ(tile_limit, sharding.TileLimitForDevice(shape, device_index));
    };
    return sharding.EachTile(shape.dimensions(), check_tile);
  };

  // 6-way sharded along axis 0 (12 rows divide evenly), 1-way along axis 1.
  TF_EXPECT_OK(validate(ShapeUtil::MakeShape(U32, {12, 20}),
                        HloSharding::Tile(TileAssignment({6, 1}))));

  // Same sharding, but 11 rows are not divisible by 6, exercising uneven
  // tiles.
  TF_EXPECT_OK(validate(ShapeUtil::MakeShape(U32, {11, 20}),
                        HloSharding::Tile(TileAssignment({6, 1}))));

  // 2-way sharded along axis 0, 1-way along axis 1, each shard replicated
  // 3 times, built with PartialTile.
  TF_EXPECT_OK(validate(ShapeUtil::MakeShape(U32, {10, 20}),
                        HloSharding::PartialTile(TileAssignment({2, 1, 3}))));

  // The same layout built with Subgroup and a single REPLICATED subgroup type.
  TF_EXPECT_OK(validate(ShapeUtil::MakeShape(U32, {10, 20}),
                        HloSharding::Subgroup(TileAssignment({2, 1, 3}),
                                              {OpSharding::REPLICATED})));
}

TEST_F(HloShardingTest, V1V2TileEquivalence) {
  // A tile sharding built from an explicit device array (v1) and from an iota
  // description (v2) must compare equal and hash identically.
  auto expect_equivalent = [](const HloSharding& v1, const HloSharding& v2) {
    EXPECT_EQ(v1, v2);
    EXPECT_EQ(absl::HashOf(v1), absl::HashOf(v2));
  };

  expect_equivalent(HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3})),
                    HloSharding::IotaTile({2, 2}));
  expect_equivalent(HloSharding::Tile(MakeArray({2, 2}, {0, 2, 1, 3})),
                    HloSharding::IotaTile({2, 2}, {2, 2}, {1, 0}));
  expect_equivalent(
      HloSharding::Tile(MakeArray({2, 2, 2}, {0, 2, 4, 6, 1, 3, 5, 7})),
      HloSharding::IotaTile({2, 2, 2}, {2, 2, 2}, {2, 0, 1}));
}

TEST_F(HloShardingTest, V1V2PartialTileEquivalence) {
  // Partial-tile shardings from an explicit device array (v1) and from an
  // iota TileAssignment (v2) must compare equal and hash identically.
  auto expect_equivalent = [](const HloSharding& v1, const HloSharding& v2) {
    EXPECT_EQ(v1, v2);
    EXPECT_EQ(absl::HashOf(v1), absl::HashOf(v2));
  };

  expect_equivalent(
      HloSharding::PartialTile(MakeArray({2, 2}, {0, 1, 2, 3})),
      HloSharding::PartialTile(TileAssignment({2, 2})));
  expect_equivalent(
      HloSharding::PartialTile(MakeArray({2, 2}, {0, 2, 1, 3})),
      HloSharding::PartialTile(TileAssignment({2, 2}, {2, 2}, {1, 0})));
  expect_equivalent(
      HloSharding::PartialTile(MakeArray({2, 2, 2}, {0, 2, 4, 6, 1, 3, 5, 7})),
      HloSharding::PartialTile(
          TileAssignment({2, 2, 2}, {2, 2, 2}, {2, 0, 1})));
}

TEST_F(HloShardingTest, V1V2SubgroupEquivalence) {
  // Subgroup shardings (MANUAL + REPLICATED) from an explicit device array (v1)
  // and from an iota TileAssignment (v2) must compare equal and hash
  // identically.
  auto expect_equivalent = [](const HloSharding& v1, const HloSharding& v2) {
    EXPECT_EQ(v1, v2);
    EXPECT_EQ(absl::HashOf(v1), absl::HashOf(v2));
  };

  expect_equivalent(
      HloSharding::Subgroup(MakeArray({2, 2}, {0, 1, 2, 3}),
                            {OpSharding::MANUAL, OpSharding::REPLICATED}),
      HloSharding::Subgroup(TileAssignment({2, 2}),
                            {OpSharding::MANUAL, OpSharding::REPLICATED}));
  expect_equivalent(
      HloSharding::Subgroup(MakeArray({2, 2}, {0, 2, 1, 3}),
                            {OpSharding::MANUAL, OpSharding::REPLICATED}),
      HloSharding::Subgroup(TileAssignment({2, 2}, {2, 2}, {1, 0}),
                            {OpSharding::MANUAL, OpSharding::REPLICATED}));
  expect_equivalent(
      HloSharding::Subgroup(MakeArray({2, 2, 2}, {0, 2, 4, 6, 1, 3, 5, 7}),
                            {OpSharding::MANUAL, OpSharding::REPLICATED}),
      HloSharding::Subgroup(TileAssignment({2, 2, 2}, {2, 2, 2}, {2, 0, 1}),
                            {OpSharding::MANUAL, OpSharding::REPLICATED}));
}

// Tests that an empty tuple is supported: a single sharding can still be
// extracted from it, and it keeps the requested representation.
TEST_P(HloShardingRepresentationTest, EmptySingleTuple) {
  bool use_named_sharding = GetParam();
  HloSharding sharding = HloSharding::SingleTuple(
      ShapeUtil::MakeTupleShape({}),
      HloSharding::AssignDevice(0, {}, use_named_sharding));
  // ASSERT, not EXPECT: the result is dereferenced below, which must not
  // happen when it is empty.
  auto single_sharding = sharding.ExtractSingleSharding();
  ASSERT_TRUE(single_sharding.has_value());
  EXPECT_EQ(single_sharding->UseNamedShardingLeaf(), use_named_sharding);
}

// Tests that an empty tuple sharding does not report any shard-group flavor.
TEST_P(HloShardingRepresentationTest, EmptySingleTupleIsNotShardGroup) {
  const HloSharding device_zero =
      HloSharding::AssignDevice(0, {}, /*use_named_sharding=*/GetParam());
  const HloSharding sharding =
      HloSharding::SingleTuple(ShapeUtil::MakeTupleShape({}), device_zero);
  EXPECT_FALSE(sharding.IsShardGroup());
  EXPECT_FALSE(sharding.IsShardAs());
  EXPECT_FALSE(sharding.IsShardLike());
}

// Runs every HloShardingRepresentationTest with use_named_sharding = false,
// then true (`::testing::Bool()` yields exactly these values in this order).
INSTANTIATE_TEST_SUITE_P(HloShardingRepresentationTest,
                         HloShardingRepresentationTest, ::testing::Bool());

TEST_F(HloShardingTest, NestedTuple) {
  // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6])
  Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({
      ShapeUtil::MakeShape(F32, {}),
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3})}),
      ShapeUtil::MakeShape(F32, {4, 6}),
  });

  // The TUPLE proto lists one sharding per leaf of the nested shape, in leaf
  // order: {0}, {1, 0}, {2}.
  HloSharding tiled_sharding = HloSharding::Tile(Array<int64_t>({{0, 1}}));
  OpSharding proto;
  proto.set_type(OpSharding::TUPLE);
  *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto();
  *proto.add_tuple_shardings() = HloSharding::AssignDevice(0).ToProto();
  *proto.add_tuple_shardings() = tiled_sharding.ToProto();
  // Report a FromProto error as a test failure instead of terminating the
  // test binary via `StatusOr::value()`.
  TF_ASSERT_OK_AND_ASSIGN(HloSharding tuple_sharding,
                          HloSharding::FromProto(proto));

  ShapeTree<HloSharding> shape_tree =
      tuple_sharding.GetAsShapeTree(nested_tuple_shape);
  EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate());
  EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0));
  EXPECT_EQ(shape_tree.element({2}), tiled_sharding);

  EXPECT_IS_OK(tuple_sharding.Validate(nested_tuple_shape, /*num_devices=*/2));
  // Test should fail because tuple element count does not match.
  EXPECT_IS_NOT_OK(tuple_sharding.Validate(ShapeUtil::MakeTupleShape({}),
                                           /*num_devices=*/5));
  // Test should fail because the input type is not a tuple.
  EXPECT_IS_NOT_OK(tuple_sharding.Validate(ShapeUtil::MakeShape(F32, {}),
                                           /*num_devices=*/5));
}

TEST_F(HloShardingTest, NormalizeTrivialSubgroupToManual) {
  // Tile dims are [1, 2, 1]. Assuming the subgroup types map onto the trailing
  // tile dimensions, only the MANUAL subgroup spans more than one device, so
  // the sharding is expected to collapse to a plain manual sharding.
  const HloSharding sharding =
      HloSharding::Subgroup(MakeArray({1, 2, 1}, {0, 1}),
                            {OpSharding::MANUAL, OpSharding::REPLICATED});
  EXPECT_TRUE(sharding.IsManual());
}

TEST_F(HloShardingTest, Hash) {
  // True only when both the hashes and operator== agree, so shardings
  // expected to be equal must also hash equally.
  auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) {
    return absl::HashOf(a) == absl::HashOf(b) && a == b;
  };

  // Leaf (non-tuple) shardings.
  EXPECT_TRUE(
      hash_compare_equal(HloSharding::Replicate(), HloSharding::Replicate()));
  EXPECT_TRUE(hash_compare_equal(HloSharding::AssignDevice(1),
                                 HloSharding::AssignDevice(1)));
  EXPECT_FALSE(hash_compare_equal(HloSharding::AssignDevice(1),
                                  HloSharding::AssignDevice(2)));
  EXPECT_TRUE(
      hash_compare_equal(HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})),
                         HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}))));
  // An iota tile and its explicit device-list spelling are interchangeable.
  EXPECT_TRUE(hash_compare_equal(
      HloSharding::IotaTile({3, 4}),
      HloSharding::Tile(
          MakeArray({3, 4}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}))));

  const HloSharding default_sharding = HloSharding::Replicate();

  // Empty tuples: never equal to a leaf sharding, always equal to each other.
  const ShapeTree<HloSharding> empty_tree(ShapeUtil::MakeTupleShape({}),
                                          default_sharding);
  EXPECT_FALSE(hash_compare_equal(HloSharding::Replicate(),
                                  HloSharding::Tuple(empty_tree)));
  EXPECT_TRUE(hash_compare_equal(HloSharding::Tuple(empty_tree),
                                 HloSharding::Tuple(empty_tree)));

  // One-element tuples, compared by their single leaf sharding.
  auto one_element_tuple = [&default_sharding](const HloSharding& leaf) {
    ShapeTree<HloSharding> tree(
        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
        default_sharding);
    *tree.mutable_element({0}) = leaf;
    return HloSharding::Tuple(tree);
  };
  EXPECT_FALSE(
      hash_compare_equal(one_element_tuple(HloSharding::Replicate()),
                         one_element_tuple(HloSharding::AssignDevice(0))));
  EXPECT_TRUE(
      hash_compare_equal(one_element_tuple(HloSharding::AssignDevice(0)),
                         one_element_tuple(HloSharding::AssignDevice(0))));
}

// Parameter for the *ShardingWithMetadataTest suites: the metadata to attach,
// and the expected ToString(/*include_metadata=*/true) output.
using ShardingWithMetadataParamType =
    std::tuple<std::vector<OpMetadata>, std::string>;

TEST_P(HloShardingRepresentationTest, ToStringReplicatedTest) {
  // Both representations print identically.
  const HloSharding sharding =
      HloSharding::Replicate({}, /*use_named_sharding=*/GetParam());
  EXPECT_EQ(sharding.ToString(), "{replicated}");
}

class HloReplicateShardingWithMetadataTest
    : public ::testing::TestWithParam<ShardingWithMetadataParamType> {};

TEST_P(HloReplicateShardingWithMetadataTest, ToStringTest) {
  HloSharding sharding = HloSharding::Replicate(std::get<0>(GetParam()));
  EXPECT_EQ(sharding.ToString(/*include_metadata=*/false), "{replicated}");
  EXPECT_EQ(sharding.ToString(/*include_metadata=*/true),
            std::get<1>(GetParam()));
}

INSTANTIATE_TEST_SUITE_P(
    ToString, HloReplicateShardingWithMetadataTest,
    ::testing::Values(
        std::make_tuple(std::vector<OpMetadata>(), "{replicated}"),
        std::make_tuple(SingleMetadata(),
                        "{replicated metadata={op_name=\"a\"}}"),
        std::make_tuple(
            ListMetadata(),
            "{replicated metadata={{op_name=\"b\"}, {op_name=\"c\"}}}")));

TEST_P(HloShardingRepresentationTest, ToStringAssignDeviceTest) {
  bool use_named_sharding = GetParam();
  HloSharding sharding = HloSharding::AssignDevice(7, {}, use_named_sharding);
  EXPECT_EQ(sharding.ToString(), "{maximal device=7}");
}

class HloAssignDeviceShardingWithMetadataTest
    : public ::testing::TestWithParam<ShardingWithMetadataParamType> {};

TEST_P(HloAssignDeviceShardingWithMetadataTest, ToStringTest) {
  HloSharding sharding = HloSharding::AssignDevice(7, std::get<0>(GetParam()));
  EXPECT_EQ(sharding.ToString(/*include_metadata=*/false),
            "{maximal device=7}");
  EXPECT_EQ(sharding.ToString(/*include_metadata=*/true),
            std::get<1>(GetParam()));
}

INSTANTIATE_TEST_SUITE_P(
    ToString, HloAssignDeviceShardingWithMetadataTest,
    ::testing::Values(
        std::make_tuple(std::vector<OpMetadata>(), "{maximal device=7}"),
        std::make_tuple(SingleMetadata(),
                        "{maximal device=7 metadata={op_name=\"a\"}}"),
        std::make_tuple(
            ListMetadata(),
            "{maximal device=7 metadata={{op_name=\"b\"}, {op_name=\"c\"}}}")));

TEST_F(HloShardingTest, ToStringTiledTest) {
  // An explicit tile assignment prints its dimensions, then the device list.
  const Array3D<int64_t> devices({{{2, 3}}, {{5, 7}}});
  EXPECT_EQ(HloSharding::Tile(devices).ToString(), "{devices=[2,1,2]2,3,5,7}");
}

TEST_F(HloShardingTest, ToStringIotaTiledTest) {
  // An iota tile assignment prints in its compact <=[reshape]T(perm) form.
  EXPECT_EQ(HloSharding::IotaTile({3, 4}, {2, 2, 3}, {2, 1, 0}).ToString(),
            "{devices=[3,4]<=[2,2,3]T(2,1,0)}");
}

class HloTiledShardingWithMetadataTest
    : public ::testing::TestWithParam<ShardingWithMetadataParamType> {};

TEST_P(HloTiledShardingWithMetadataTest, ToStringTest) {
  HloSharding sharding = HloSharding::Tile(
      Array3D<int64_t>({{{2, 3}}, {{5, 7}}}), std::get<0>(GetParam()));
  EXPECT_EQ(sharding.ToString(/*include_metadata=*/false),
            "{devices=[2,1,2]2,3,5,7}");
  EXPECT_EQ(sharding.ToString(/*include_metadata=*/true),
            std::get<1>(GetParam()));
}

INSTANTIATE_TEST_SUITE_P(
    ToString, HloTiledShardingWithMetadataTest,
    ::testing::Values(
        std::make_tuple(std::vector<OpMetadata>(), "{devices=[2,1,2]2,3,5,7}"),
        std::make_tuple(SingleMetadata(),
                        "{devices=[2,1,2]2,3,5,7 metadata={op_name=\"a\"}}"),
        std::make_tuple(ListMetadata(),
                        "{devices=[2,1,2]2,3,5,7 metadata={{op_name=\"b\"}, "
                        "{op_name=\"c\"}}}")));

TEST_F(HloShardingTest, ToStringTupleTest) {
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}),
                                 ShapeUtil::MakeShape(U32, {7, 25}),
                                 ShapeUtil::MakeShape(S32, {9, 11})});
  const std::vector<HloSharding> elements = {
      HloSharding::Replicate(), HloSharding::Tile(Array2D<int64_t>({{3, 5}})),
      HloSharding::AssignDevice(3)};
  // A tuple sharding prints its element shardings, comma-separated, inside
  // an outer pair of braces.
  EXPECT_EQ(HloSharding::Tuple(tuple_shape, elements).ToString(),
            "{{replicated}, {devices=[1,2]3,5}, {maximal device=3}}");
}

TEST_F(HloShardingTest, ToStringTupleWithMetadataTest) {
  // Removed an unused `metadata` local that was never referenced.
  HloSharding sharding = HloSharding::Tuple(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}),
                                 ShapeUtil::MakeShape(U32, {7, 25}),
                                 ShapeUtil::MakeShape(S32, {9, 11})}),
      {HloSharding::Replicate({GetMetadata("d")}),
       HloSharding::Tile(Array2D<int64_t>({{3, 5}})),
       HloSharding::AssignDevice(3, {GetMetadata("e")})});
  // Element metadata is printed only when requested, and only for elements
  // that carry it (the tiled element has none).
  EXPECT_EQ(sharding.ToString(/*include_metadata=*/false),
            "{{replicated}, {devices=[1,2]3,5}, {maximal device=3}}");
  EXPECT_EQ(sharding.ToString(/*include_metadata=*/true),
            "{{replicated metadata={op_name=\"d\"}}, {devices=[1,2]3,5}, "
            "{maximal device=3 metadata={op_name=\"e\"}}}");
}

TEST_F(HloShardingTest, ToStringWithNamedShardingTest) {
  // A 2x4 mesh with axes "a" and "b". Tensor dimension 0 is sharded on axis
  // index 0 ("a") and dimension 1 on axis index 1 ("b").
  Mesh mesh({2, 4}, {"a", "b"});
  NamedSharding::DimensionSharding dim0_on_a({AxisRef(0)},
                                             /*is_closed=*/true);
  NamedSharding::DimensionSharding dim1_on_b({AxisRef(1)},
                                             /*is_closed=*/true);

  // Without metadata.
  HloSharding plain(NamedSharding(mesh, {{dim0_on_a}, {dim1_on_b}}));
  EXPECT_EQ(plain.ToString(), "{@mesh<a=2,b=4>, [{a}, {b}]}");

  // With metadata, which is printed after the dimension shardings.
  HloSharding with_metadata(NamedSharding(mesh, {{dim0_on_a}, {dim1_on_b}}, {},
                                          {}, {}, ListMetadata()));
  EXPECT_EQ(with_metadata.ToString(/*include_metadata=*/true),
            "{@mesh<a=2,b=4>, [{a}, {b}], metadata={{op_name=\"b\"}, "
            "{op_name=\"c\"}}}");

  // A tuple of named shardings prints each element in order.
  HloSharding tuple_sharding(HloSharding::Tuple(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}),
                                 ShapeUtil::MakeShape(U32, {7, 25}),
                                 ShapeUtil::MakeShape(S32, {9, 11})}),
      {plain, plain, with_metadata}));
  EXPECT_EQ(tuple_sharding.ToString(/*include_metadata=*/true),
            "{{@mesh<a=2,b=4>, [{a}, {b}]}, {@mesh<a=2,b=4>, [{a}, {b}]}, "
            "{@mesh<a=2,b=4>, [{a}, {b}], metadata={{op_name=\"b\"}, "
            "{op_name=\"c\"}}}}");
}

TEST_F(HloShardingTest, OstreamTest) {
  // Streaming a sharding writes its textual form.
  std::ostringstream stream;
  stream << HloSharding::Tile(Array4D<int64_t>({{{{0, 1}, {2, 3}}}}));
  EXPECT_EQ(stream.str(), "{devices=[1,1,2,2]0,1,2,3}");
}

class HloParseShardingWithMetadataTest
    : public ::testing::TestWithParam<std::vector<OpMetadata>> {};

TEST_P(HloParseShardingWithMetadataTest, ParseHloString) {
  auto check = [](const HloSharding& sharding) {
    TF_ASSERT_OK_AND_ASSIGN(
        auto parsed_sharding,
        ParseSharding(sharding.ToString(/*include_metadata=*/true)));
    EXPECT_EQ(sharding, parsed_sharding);
  };
  check(HloSharding::Replicate(GetParam()));
  check(HloSharding::AssignDevice(2, GetParam()));
  check(HloSharding::Tile(Array4D<int64_t>({{{{0}, {1}}}}), GetParam()));
  // Empty tuple. One sharding is required for empty tuples, as we need to be
  // able to assign sharding to them, even though they have no leaves.
  check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}),
                           {HloSharding::Replicate(GetParam())}));
  {
    // Non-nested tuple.
    auto tuple_shape =
        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 1, 5, 7}),
                                   ShapeUtil::MakeShape(F32, {3, 5, 7}),
                                   ShapeUtil::MakeShape(F32, {3, 7})});
    check(HloSharding::Tuple(
        tuple_shape,
        {HloSharding::Tile(Array4D<int64_t>({{{{0}, {1}}}})),
         HloSharding::Replicate(GetParam()), HloSharding::AssignDevice(1)}));
  }
  {
    // Nested tuple.
    auto tuple_shape = ShapeUtil::MakeTupleShape(
        {ShapeUtil::MakeShape(F32, {3, 1, 5, 7}),
         ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}),
                                    ShapeUtil::MakeShape(F32, {3, 7})})});
    std::vector<HloSharding> leaf_shardings = {
        HloSharding::Tile(Array4D<int64_t>({{{{0}, {1}}}})),
        HloSharding::Replicate(), HloSharding::AssignDevice(1, GetParam())};
    ShapeTree<HloSharding> sharding_tree(tuple_shape, HloSharding::Replicate());
    // Assign leaf_shardings to sharding_tree leaves.
    auto it = leaf_shardings.begin();
    for (auto& index_to_sharding : sharding_tree.leaves()) {
      index_to_sharding.second = *it++;
    }
    check(HloSharding::Tuple(sharding_tree));
  }
}

INSTANTIATE_TEST_SUITE_P(ParseHloString, HloParseShardingWithMetadataTest,
                         ::testing::Values(std::vector<OpMetadata>(),
                                           SingleMetadata(), ListMetadata()));

TEST_F(HloShardingTest, WithMetadataNoOverwrite) {
  // With overwrite=false, new metadata lands only where none exists yet.
  const std::vector<OpMetadata> single = SingleMetadata();
  const std::vector<OpMetadata> list = ListMetadata();

  {
    // No prior metadata: the new entry is attached.
    auto updated =
        HloSharding::Replicate().WithMetadata(single, /*overwrite=*/false);
    ASSERT_EQ(updated.metadata().size(), 1);
    EXPECT_THAT(updated.metadata().front(), EqualsProto(single.front()));
  }

  {
    // Prior metadata is kept and the new list is ignored.
    HloSharding original = HloSharding::AssignDevice(7, single);
    auto updated = original.WithMetadata(list, /*overwrite=*/false);
    ASSERT_EQ(updated.metadata().size(), 1);
    EXPECT_THAT(updated.metadata().front(),
                EqualsProto(original.metadata().front()));
  }

  {
    // Tuples: the decision is made separately for each element.
    HloSharding original = HloSharding::Tuple(
        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}),
                                   ShapeUtil::MakeShape(U32, {7, 25}),
                                   ShapeUtil::MakeShape(S32, {9, 11})}),
        {HloSharding::Replicate(single),
         HloSharding::Tile(Array2D<int64_t>({{3, 5}})),
         HloSharding::AssignDevice(3, single)});
    auto updated = original.WithMetadata(list, /*overwrite=*/false);
    EXPECT_TRUE(updated.metadata().empty());
    ASSERT_TRUE(updated.IsTuple());
    const auto& elements = updated.tuple_elements();
    ASSERT_EQ(elements.size(), 3);

    // Element 0 already had metadata and keeps it.
    ASSERT_EQ(elements[0].metadata().size(), 1);
    EXPECT_THAT(elements[0].metadata().front(), EqualsProto(single.front()));

    // Element 1 had none and receives the whole new list.
    ASSERT_EQ(elements[1].metadata().size(), 2);
    for (int i = 0; i < 2; ++i) {
      EXPECT_THAT(elements[1].metadata()[i], EqualsProto(list[i]));
    }

    // Element 2 already had metadata and keeps it.
    ASSERT_EQ(elements[2].metadata().size(), 1);
    EXPECT_THAT(elements[2].metadata().front(), EqualsProto(single.front()));
  }
}

TEST_F(HloShardingTest, WithMetadataOverwrite) {
  // With overwrite=true, the new metadata replaces any existing metadata, and
  // for tuples it is applied to every element.
  const std::vector<OpMetadata> single = SingleMetadata();
  const std::vector<OpMetadata> list = ListMetadata();

  {
    auto updated =
        HloSharding::Replicate().WithMetadata(single, /*overwrite=*/true);
    ASSERT_EQ(updated.metadata().size(), 1);
    EXPECT_THAT(updated.metadata().front(), EqualsProto(single.front()));
  }

  {
    HloSharding original = HloSharding::AssignDevice(7, single);
    auto updated = original.WithMetadata(list, /*overwrite=*/true);
    ASSERT_EQ(updated.metadata().size(), 2);
    for (int i = 0; i < 2; ++i) {
      EXPECT_THAT(updated.metadata()[i], EqualsProto(list[i]));
    }
  }

  {
    HloSharding original = HloSharding::Tuple(
        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}),
                                   ShapeUtil::MakeShape(U32, {7, 25}),
                                   ShapeUtil::MakeShape(S32, {9, 11})}),
        {HloSharding::Replicate(single),
         HloSharding::Tile(Array2D<int64_t>({{3, 5}})),
         HloSharding::AssignDevice(3, single)});
    auto updated = original.WithMetadata(list, /*overwrite=*/true);
    EXPECT_TRUE(updated.metadata().empty());
    ASSERT_TRUE(updated.IsTuple());
    ASSERT_EQ(updated.tuple_elements().size(), 3);

    for (const HloSharding& element : updated.tuple_elements()) {
      ASSERT_EQ(element.metadata().size(), 2);
      for (int i = 0; i < 2; ++i) {
        EXPECT_THAT(element.metadata()[i], EqualsProto(list[i]));
      }
    }
  }
}

TEST_F(HloShardingTest, WithoutMetadata) {
  // WithoutMetadata strips metadata from the sharding and, for tuples, from
  // every element.
  {
    auto stripped = HloSharding::Replicate().WithoutMetadata();
    EXPECT_TRUE(stripped.metadata().empty());
  }

  {
    HloSharding original = HloSharding::AssignDevice(7, SingleMetadata());
    auto stripped = original.WithoutMetadata();
    EXPECT_TRUE(stripped.metadata().empty());
  }

  {
    HloSharding original = HloSharding::Tuple(
        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}),
                                   ShapeUtil::MakeShape(U32, {7, 25}),
                                   ShapeUtil::MakeShape(S32, {9, 11})}),
        {HloSharding::Replicate(SingleMetadata()),
         HloSharding::Tile(Array2D<int64_t>({{3, 5}})),
         HloSharding::AssignDevice(3, ListMetadata())});
    auto stripped = original.WithoutMetadata();
    EXPECT_TRUE(stripped.metadata().empty());
    ASSERT_TRUE(stripped.IsTuple());
    EXPECT_EQ(stripped.tuple_elements().size(), 3);
    for (const HloSharding& element : stripped.tuple_elements()) {
      EXPECT_TRUE(element.metadata().empty());
    }
  }
}

}  // namespace
}  // namespace xla
