/* Copyright 2023 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/pjrt/gpu/se_gpu_pjrt_compiler.h"

#include <memory>
#include <optional>
#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/testlib/test.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/gpu/se_gpu_topology_description.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h"
#include "xla/service/gpu_topology.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tsl/platform/statusor.h"

namespace xla {
namespace {

// Placeholder string that the tests below pass as the `platform_version`
// argument of GetGpuTopology() when building a topology description.
constexpr absl::string_view kFakeDeviceName = "Fake_device";

// HLO text for a module whose entry computation takes no parameters and
// returns the scalar s32 constant 2. Consumed by the *Xla tests through
// GetXlaComputation().
constexpr absl::string_view kProgram = R"(HloModule Computation

ENTRY Computation() -> s32[] {
  ROOT result = s32[] constant(2)
})";

// MLIR text for a module whose `main` function takes no arguments and returns
// the mhlo constant dense<2> as a tensor<i32>. Parsed by the *Mlir tests.
constexpr absl::string_view mlir_str = R"mlir(
  module {
    func.func @main() -> tensor<i32> {
      %0 = mhlo.constant dense<2> : tensor<i32>
      return %0 : tensor<i32>
    }
  })mlir";

// Turns HLO text into an XlaComputation.
//
// `program` is parsed with ParseAndReturnUnverifiedModule(); if parsing fails,
// that error status is propagated to the caller. On success the parsed
// module's proto is wrapped in an XlaComputation and returned.
absl::StatusOr<xla::XlaComputation> GetXlaComputation(
    absl::string_view program) {
  // The second argument is left value-initialized, exactly as `{}`.
  TF_ASSIGN_OR_RETURN(auto parsed_module,
                      xla::ParseAndReturnUnverifiedModule(program, {}));

  return xla::XlaComputation(parsed_module->ToProto());
}

std::shared_ptr<xla::GpuTopology> GetGpuTopology(
    absl::string_view platform_version, int num_partitions,
    int num_hosts_per_partition, int num_devices_per_host,
    int core_count_per_chip) {
  return std::make_shared<xla::GpuTopology>(platform_version, num_partitions,
                                            num_hosts_per_partition,
                                            num_devices_per_host);
}

// Compiling an XlaComputation against a topology description while passing a
// null client is expected to report kUnimplemented.
TEST(StreamExecutorGpuCompilerTest, NoClientXla) {
  using ::absl_testing::StatusIs;

  StreamExecutorGpuCompiler compiler;
  StreamExecutorGpuTopologyDescription fake_topology(
      CudaId(), CudaName(),
      GetGpuTopology(kFakeDeviceName, /*num_partitions=*/1,
                     /*num_hosts_per_partition=*/1,
                     /*num_devices_per_host=*/2,
                     /*core_count_per_chip=*/10));

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_computation, GetXlaComputation(kProgram));

  const auto compile_result =
      compiler.Compile(xla::CompileOptions(), hlo_computation, fake_topology,
                       /*client=*/nullptr);
  EXPECT_THAT(compile_result, StatusIs(absl::StatusCode::kUnimplemented));
}

// Compiles an XlaComputation with a real GPU client while also supplying a
// topology description built from the fake device name. The compile call is
// expected to succeed (kOk).
TEST(StreamExecutorGpuCompilerTest, TopologyNotSameXla) {
  using ::absl_testing::StatusIs;

  StreamExecutorGpuCompiler compiler;
  StreamExecutorGpuTopologyDescription fake_topology(
      CudaId(), CudaName(),
      GetGpuTopology(kFakeDeviceName, /*num_partitions=*/1,
                     /*num_hosts_per_partition=*/1,
                     /*num_devices_per_host=*/2,
                     /*core_count_per_chip=*/10));

  TF_ASSERT_OK_AND_ASSIGN(auto gpu_client,
                          GetStreamExecutorGpuClient(GpuClientOptions()));
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_computation, GetXlaComputation(kProgram));

  const auto compile_result =
      compiler.Compile(xla::CompileOptions(), hlo_computation, fake_topology,
                       gpu_client.get());
  EXPECT_THAT(compile_result, StatusIs(absl::StatusCode::kOk));
}

// End-to-end check for the XLA path: compile and load kProgram through the GPU
// client, execute it with no arguments, and verify that the single output
// buffer holds the scalar 2.
TEST(StreamExecutorGpuCompilerTest, SuccessXla) {
  StreamExecutorGpuCompiler compiler;

  TF_ASSERT_OK_AND_ASSIGN(auto gpu_client,
                          GetStreamExecutorGpuClient(GpuClientOptions()));
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_computation, GetXlaComputation(kProgram));
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<xla::PjRtLoadedExecutable> executable,
      gpu_client->CompileAndLoad(hlo_computation, xla::CompileOptions()));

  // One (empty) argument list, i.e. a single execution with no inputs.
  TF_ASSERT_OK_AND_ASSIGN(auto per_device_outputs,
                          executable->Execute(
                              /*argument_handles=*/{{}}, /*options=*/{}));

  // Expect exactly one result list containing exactly one buffer.
  ASSERT_EQ(per_device_outputs.size(), 1);
  std::vector<std::unique_ptr<xla::PjRtBuffer>>& output_buffers =
      per_device_outputs.front();
  ASSERT_EQ(output_buffers.size(), 1);

  TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<xla::Literal> actual,
                          output_buffers.front()->ToLiteral().Await());
  const xla::Literal expected = LiteralUtil::CreateR0(2);
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual));
}

// Compiling an MLIR module against a topology description while passing a null
// client is expected to report kUnimplemented.
TEST(StreamExecutorGpuCompilerTest, NoClientMlir) {
  StreamExecutorGpuCompiler compiler;

  mlir::MLIRContext context;
  context.loadDialect<mlir::mhlo::MhloDialect, mlir::func::FuncDialect>();

  auto mlir_module =
      mlir::parseSourceString<mlir::ModuleOp>(mlir_str, &context);
  // Stop here on a parse failure instead of handing a null module to Compile().
  ASSERT_TRUE(mlir_module) << "failed to parse the test MLIR module";

  StreamExecutorGpuTopologyDescription topology(
      CudaId(), CudaName(), GetGpuTopology(kFakeDeviceName, 1, 1, 2, 10));

  EXPECT_THAT(
      compiler.Compile(xla::CompileOptions(), mlir_module.get(), topology,
                       /*client=*/nullptr),
      absl_testing::StatusIs(absl::StatusCode::kUnimplemented));
}

// Compiles an MLIR module with a real GPU client while also supplying a
// topology description built from the fake device name. The compile call is
// expected to succeed (kOk).
TEST(StreamExecutorGpuCompilerTest, TopologyNotSameMlir) {
  StreamExecutorGpuCompiler compiler;

  mlir::MLIRContext context;
  context.loadDialect<mlir::mhlo::MhloDialect, mlir::func::FuncDialect>();

  auto mlir_module =
      mlir::parseSourceString<mlir::ModuleOp>(mlir_str, &context);
  // Stop here on a parse failure instead of handing a null module to Compile().
  ASSERT_TRUE(mlir_module) << "failed to parse the test MLIR module";

  StreamExecutorGpuTopologyDescription topology(
      CudaId(), CudaName(), GetGpuTopology(kFakeDeviceName, 1, 1, 2, 10));

  TF_ASSERT_OK_AND_ASSIGN(auto client,
                          GetStreamExecutorGpuClient(GpuClientOptions()));
  EXPECT_THAT(compiler.Compile(xla::CompileOptions(), mlir_module.get(),
                               topology, client.get()),
              absl_testing::StatusIs(absl::StatusCode::kOk));
}

// End-to-end check for the MLIR path: parse mlir_str, compile and load it
// through the GPU client, execute it with no arguments, and verify that the
// single output buffer holds the scalar 2.
TEST(StreamExecutorGpuCompilerTest, SuccessMlir) {
  StreamExecutorGpuCompiler compiler;

  mlir::MLIRContext context;
  context.loadDialect<mlir::mhlo::MhloDialect, mlir::func::FuncDialect>();

  auto mlir_module =
      mlir::parseSourceString<mlir::ModuleOp>(mlir_str, &context);
  // Stop here on a parse failure instead of compiling a null module.
  ASSERT_TRUE(mlir_module) << "failed to parse the test MLIR module";

  TF_ASSERT_OK_AND_ASSIGN(auto client,
                          GetStreamExecutorGpuClient(GpuClientOptions()));
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<xla::PjRtLoadedExecutable> loaded_executable,
      client->CompileAndLoad(mlir_module.get(), xla::CompileOptions()));

  // One (empty) argument list, i.e. a single execution with no inputs.
  TF_ASSERT_OK_AND_ASSIGN(auto result,
                          loaded_executable->Execute(
                              /*argument_handles=*/{{}}, /*options=*/{}));

  // Expect exactly one result list containing exactly one buffer.
  ASSERT_EQ(result.size(), 1);
  std::vector<std::unique_ptr<xla::PjRtBuffer>>& result_buffers = result[0];
  ASSERT_EQ(result_buffers.size(), 1);
  TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<xla::Literal> result_literal,
                          result_buffers[0]->ToLiteral().Await());
  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR0(2), *result_literal));
}

// Round-trip check: compile the MLIR module with the standalone compiler,
// serialize the resulting executable, load it back through the client, run it
// with no arguments, and verify that the single output buffer holds the
// scalar 2.
TEST(StreamExecutorGpuCompilerTest, SuccessMlirCanBeSerialized) {
  StreamExecutorGpuCompiler compiler;

  mlir::MLIRContext context;
  context.loadDialect<mlir::mhlo::MhloDialect, mlir::func::FuncDialect>();

  auto mlir_module =
      mlir::parseSourceString<mlir::ModuleOp>(mlir_str, &context);
  // Stop here on a parse failure instead of compiling a null module.
  ASSERT_TRUE(mlir_module) << "failed to parse the test MLIR module";

  TF_ASSERT_OK_AND_ASSIGN(auto client,
                          GetStreamExecutorGpuClient(GpuClientOptions()));

  StreamExecutorGpuTopologyDescription topology(
      CudaId(), CudaName(), GetGpuTopology(kFakeDeviceName, 1, 1, 2, 10));

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<xla::PjRtExecutable> executable,
      compiler.Compile(xla::CompileOptions(), mlir_module.get(), topology,
                       client.get()));

  // The serialized form must be non-empty before it is worth loading back.
  TF_ASSERT_OK_AND_ASSIGN(std::string serialized,
                          executable->SerializeExecutable());
  ASSERT_FALSE(serialized.empty());

  TF_ASSERT_OK_AND_ASSIGN(auto loaded_executable_from_serialized,
                          client->LoadSerializedExecutable(
                              serialized, std::nullopt, xla::LoadOptions()));

  // One (empty) argument list, i.e. a single execution with no inputs.
  TF_ASSERT_OK_AND_ASSIGN(auto result,
                          loaded_executable_from_serialized->Execute(
                              /*argument_handles=*/{{}}, /*options=*/{}));

  // Expect exactly one result list containing exactly one buffer.
  ASSERT_EQ(result.size(), 1);
  std::vector<std::unique_ptr<xla::PjRtBuffer>>& result_buffers = result[0];
  ASSERT_EQ(result_buffers.size(), 1);
  TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<xla::Literal> result_literal,
                          result_buffers[0]->ToLiteral().Await());
  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR0(2), *result_literal));
}

}  // namespace
}  // namespace xla
