/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <array>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/pjrt/gpu/tfrt/gpu_event.h"
#include "xla/pjrt/gpu/tfrt/tfrt_gpu_client.h"
#include "xla/pjrt/gpu/tfrt/tfrt_gpu_device.h"
#include "xla/pjrt/gpu/tfrt/thread_checker.h"
#include "xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/threadpool.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/casts.h"

namespace xla {
namespace {

using ::tsl::thread::ThreadPool;

// Test fixture that creates a TFRT GPU client for every test. The client is
// exposed both through the generic PjRtClient interface (`client_`) and as the
// concrete TfrtGpuClient (`tfrt_gpu_client_`), so tests can reach
// implementation-specific APIs such as `allocator()`.
class TfrtGpuBufferTest : public ::testing::Test {
 protected:
  void SetUp() override {
    TF_ASSERT_OK_AND_ASSIGN(client_, GetTfrtGpuClient(GpuClientOptions()));
    tfrt_gpu_client_ = tensorflow::down_cast<TfrtGpuClient*>(client_.get());
    // Tests address `client_->devices()[0]`; fail fast with a clear message
    // instead of indexing out of bounds when no device is present.
    ASSERT_FALSE(client_->devices().empty())
        << "TfrtGpuBufferTest requires at least one GPU device";
  }

  // NOTE(review): not referenced by the tests visible here; presumably kept for
  // its construction side effects -- confirm before removing.
  TfrtGpuThreadChecker thread_checker_;
  // Owns the client for the duration of the test.
  std::unique_ptr<PjRtClient> client_;
  // Non-owning, concretely-typed alias of `client_`; set in SetUp().
  TfrtGpuClient* tfrt_gpu_client_ = nullptr;
};

// Wraps freshly allocated device memory in a TfrtGpuBuffer and checks that the
// buffer reports the shape, device, client, memory space and byte size it was
// constructed with.
TEST_F(TfrtGpuBufferTest, CreateBuffer) {
  Shape on_device_shape = ShapeUtil::MakeShapeWithType<int32_t>({4, 4});
  TfrtGpuDevice* device =
      tensorflow::down_cast<TfrtGpuDevice*>(client_->devices()[0]);
  auto size_in_bytes = ShapeUtil::ByteSizeOf(on_device_shape);
  TF_ASSERT_OK_AND_ASSIGN(
      auto device_buffer,
      GpuDeviceMemory::Allocate(tfrt_gpu_client_->allocator(),
                                device->local_device_id().value(),
                                size_in_bytes));
  auto buffer_async_value_ref =
      tsl::MakeAvailableAsyncValueRef<GpuDeviceMemory>(
          std::move(device_buffer));
  // Both events are created already available.
  auto tracked_device_buffer = std::make_unique<TrackedGpuDeviceBuffer>(
      std::move(buffer_async_value_ref),
      tsl::MakeAvailableAsyncValueRef<GpuEvent>(),
      tsl::MakeAvailableAsyncValueRef<GpuEvent>());
  // ASSERT instead of `.value()` so an error fails the test rather than
  // crashing the test binary.
  TF_ASSERT_OK_AND_ASSIGN(auto memory_space, device->default_memory_space());
  auto buffer = std::make_unique<TfrtGpuBuffer>(
      on_device_shape, std::move(tracked_device_buffer), tfrt_gpu_client_,
      device, memory_space);

  EXPECT_EQ(buffer->on_device_shape(), on_device_shape);
  EXPECT_EQ(buffer->device(), device);
  EXPECT_EQ(buffer->client(), client_.get());
  EXPECT_EQ(buffer->memory_space(), memory_space);
  TF_ASSERT_OK_AND_ASSIGN(auto reported_size_in_bytes,
                          buffer->GetOnDeviceSizeInBytes());
  EXPECT_EQ(reported_size_in_bytes, size_in_bytes);
}

// AcquireExternalReference() must not return until the buffer's definition
// event has been triggered.
TEST_F(TfrtGpuBufferTest, AcquireExternalReference) {
  Shape on_device_shape = ShapeUtil::MakeShapeWithType<int32_t>({4, 4});
  TfrtGpuDevice* device =
      tensorflow::down_cast<TfrtGpuDevice*>(client_->devices()[0]);
  auto size_in_bytes = ShapeUtil::ByteSizeOf(on_device_shape);
  TF_ASSERT_OK_AND_ASSIGN(
      auto device_buffer,
      GpuDeviceMemory::Allocate(tfrt_gpu_client_->allocator(),
                                device->local_device_id().value(),
                                size_in_bytes));
  auto buffer_async_value_ref =
      tsl::MakeAvailableAsyncValueRef<GpuDeviceMemory>(
          std::move(device_buffer));
  // Constructed but not yet available; it is triggered explicitly below.
  auto definition_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();
  auto tracked_device_buffer = std::make_unique<TrackedGpuDeviceBuffer>(
      std::move(buffer_async_value_ref), definition_event, definition_event);
  TF_ASSERT_OK_AND_ASSIGN(auto memory_space, device->default_memory_space());
  auto buffer = std::make_unique<TfrtGpuBuffer>(
      on_device_shape, std::move(tracked_device_buffer), tfrt_gpu_client_,
      device, memory_space);

  // Written by the task scheduled on `thread_pool` below. These are declared
  // BEFORE the pool on purpose. Locals are destroyed in reverse declaration
  // order, so the pool's destructor, which waits for scheduled work, runs
  // first. The task therefore cannot still be inside SetStateConcrete() on an
  // event that has already been destroyed.
  absl::StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>> ref_status;
  auto ref_acquired_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();

  ThreadPool thread_pool(tsl::Env::Default(), "gpu_buffer_test",
                         /*num_threads=*/4);
  thread_pool.Schedule([&]() {
    ref_status = buffer->AcquireExternalReference();
    ref_acquired_event.SetStateConcrete();
  });
  // AcquireExternalReference should block until the definition event is
  // triggered. Only non-fatal EXPECTs are used until then: returning early
  // would leave the worker blocked and the pool destructor waiting on it.
  absl::SleepFor(absl::Milliseconds(100));
  EXPECT_FALSE(ref_acquired_event.IsAvailable());

  // Trigger the definition event. AcquireExternalReference should be unblocked.
  definition_event.SetStateConcrete();
  BlockUntilReady(ref_acquired_event.GetAsyncValue());
  EXPECT_OK(ref_status);

  // TODO(b/382117736): external reference should block donation.
}

// With wait_for_operations_to_complete=false, ReleaseDeviceMemoryOwnership()
// returns the device memory even though neither the definition event nor the
// usage event has been triggered. A second release yields nullptr.
TEST_F(TfrtGpuBufferTest, ReleaseDeviceMemoryOwnershipNoWait) {
  Shape on_device_shape = ShapeUtil::MakeShapeWithType<int32_t>({4, 4});
  TfrtGpuDevice* device =
      tensorflow::down_cast<TfrtGpuDevice*>(client_->devices()[0]);
  auto size_in_bytes = ShapeUtil::ByteSizeOf(on_device_shape);
  TF_ASSERT_OK_AND_ASSIGN(
      auto device_buffer,
      GpuDeviceMemory::Allocate(tfrt_gpu_client_->allocator(),
                                device->local_device_id().value(),
                                size_in_bytes));
  // Remember the raw device address so the released reference can be checked
  // against it.
  void* device_memory_opaque = device_buffer.buffer().opaque();
  auto buffer_async_value_ref =
      tsl::MakeAvailableAsyncValueRef<GpuDeviceMemory>(
          std::move(device_buffer));

  auto definition_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();
  // NOTE(review): set by the tracked buffer's callback but never asserted on
  // in this test.
  bool destructed = false;
  auto tracked_device_buffer = std::make_unique<TrackedGpuDeviceBuffer>(
      std::move(buffer_async_value_ref), definition_event, definition_event,
      [&] { destructed = true; });

  auto usage_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();
  std::array usage_events{usage_event.CopyRef()};
  tracked_device_buffer->AddUsageEvents(absl::MakeSpan(usage_events));

  TF_ASSERT_OK_AND_ASSIGN(auto memory_space, device->default_memory_space());
  auto buffer = std::make_unique<TfrtGpuBuffer>(
      on_device_shape, std::move(tracked_device_buffer), tfrt_gpu_client_,
      device, memory_space);

  // Release and don't wait for definition or usage events to complete. Use
  // ASSERT (not EXPECT followed by `.value()`) so an error fails the test
  // instead of crashing it.
  TF_ASSERT_OK_AND_ASSIGN(auto ref,
                          buffer->ReleaseDeviceMemoryOwnership(
                              /*wait_for_operations_to_complete=*/false));
  ASSERT_NE(ref.get(), nullptr);
  EXPECT_EQ(device_memory_opaque, ref->OpaqueDeviceMemoryDataPointer());

  // Release again should return nullptr.
  TF_ASSERT_OK_AND_ASSIGN(auto second_ref,
                          buffer->ReleaseDeviceMemoryOwnership(
                              /*wait_for_operations_to_complete=*/false));
  EXPECT_EQ(nullptr, second_ref.get());
}

// With wait_for_operations_to_complete=true, ReleaseDeviceMemoryOwnership()
// must block until both the definition event and the usage event have been
// triggered. It then returns the original device memory; a second release
// yields nullptr.
TEST_F(TfrtGpuBufferTest, ReleaseDeviceMemoryOwnershipWait) {
  Shape on_device_shape = ShapeUtil::MakeShapeWithType<int32_t>({4, 4});
  TfrtGpuDevice* device =
      tensorflow::down_cast<TfrtGpuDevice*>(client_->devices()[0]);
  auto size_in_bytes = ShapeUtil::ByteSizeOf(on_device_shape);
  TF_ASSERT_OK_AND_ASSIGN(
      auto device_buffer,
      GpuDeviceMemory::Allocate(tfrt_gpu_client_->allocator(),
                                device->local_device_id().value(),
                                size_in_bytes));
  // Remember the raw device address so the released reference can be checked
  // against it.
  void* device_memory_opaque = device_buffer.buffer().opaque();
  auto buffer_async_value_ref =
      tsl::MakeAvailableAsyncValueRef<GpuDeviceMemory>(
          std::move(device_buffer));

  auto definition_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();
  // NOTE(review): set by the tracked buffer's callback but never asserted on
  // in this test.
  bool destructed = false;
  auto tracked_device_buffer = std::make_unique<TrackedGpuDeviceBuffer>(
      std::move(buffer_async_value_ref), definition_event, definition_event,
      [&] { destructed = true; });

  auto usage_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();
  std::array usage_events{usage_event.CopyRef()};
  tracked_device_buffer->AddUsageEvents(absl::MakeSpan(usage_events));

  TF_ASSERT_OK_AND_ASSIGN(auto memory_space, device->default_memory_space());
  auto buffer = std::make_unique<TfrtGpuBuffer>(
      on_device_shape, std::move(tracked_device_buffer), tfrt_gpu_client_,
      device, memory_space);

  // Written by the task scheduled on `thread_pool` below. They are declared
  // BEFORE the pool so the pool's destructor, which waits for scheduled work,
  // runs before they are destroyed.
  absl::StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>> ref_status;
  auto ref_released_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();

  ThreadPool thread_pool(tsl::Env::Default(), "gpu_buffer_test",
                         /*num_threads=*/4);
  thread_pool.Schedule([&]() {
    ref_status = buffer->ReleaseDeviceMemoryOwnership(
        /*wait_for_operations_to_complete=*/true);
    ref_released_event.SetStateConcrete();
  });
  // ReleaseDeviceMemoryOwnership should block until the definition event is
  // triggered. Only non-fatal EXPECTs are used until both events fire:
  // returning early would leave the worker blocked and the pool destructor
  // waiting on it.
  absl::SleepFor(absl::Milliseconds(100));
  EXPECT_FALSE(ref_released_event.IsAvailable());

  // Trigger the definition event. The release must still be blocked on the
  // pending usage event. Give the worker time to make progress; otherwise this
  // check would pass even if only the definition event were awaited.
  definition_event.SetStateConcrete();
  absl::SleepFor(absl::Milliseconds(100));
  EXPECT_FALSE(ref_released_event.IsAvailable());

  // Trigger the usage event; the release should now complete.
  usage_event.SetStateConcrete();
  BlockUntilReady(ref_released_event.GetAsyncValue());

  // TODO(b/382117736): should also block until donation event is triggered.
  TF_ASSERT_OK_AND_ASSIGN(auto ref, std::move(ref_status));
  ASSERT_NE(ref.get(), nullptr);
  EXPECT_EQ(device_memory_opaque, ref->OpaqueDeviceMemoryDataPointer());

  // Release again should return nullptr.
  TF_ASSERT_OK_AND_ASSIGN(auto second_ref,
                          buffer->ReleaseDeviceMemoryOwnership(
                              /*wait_for_operations_to_complete=*/false));
  EXPECT_EQ(nullptr, second_ref.get());
}

// Delete() marks the buffer deleted immediately. The tracked device buffer's
// callback must not run until both the definition event and the usage event
// have been triggered.
TEST_F(TfrtGpuBufferTest, Delete) {
  Shape on_device_shape = ShapeUtil::MakeShapeWithType<int32_t>({4, 4});
  TfrtGpuDevice* device =
      tensorflow::down_cast<TfrtGpuDevice*>(client_->devices()[0]);
  auto size_in_bytes = ShapeUtil::ByteSizeOf(on_device_shape);
  TF_ASSERT_OK_AND_ASSIGN(
      auto device_buffer,
      GpuDeviceMemory::Allocate(tfrt_gpu_client_->allocator(),
                                device->local_device_id().value(),
                                size_in_bytes));
  auto buffer_async_value_ref =
      tsl::MakeAvailableAsyncValueRef<GpuDeviceMemory>(
          std::move(device_buffer));

  auto definition_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();
  // Flipped to true by the callback passed to the tracked buffer below.
  bool destructed = false;
  auto tracked_device_buffer = std::make_unique<TrackedGpuDeviceBuffer>(
      std::move(buffer_async_value_ref), definition_event, definition_event,
      [&] { destructed = true; });

  auto usage_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();
  std::array usage_events{usage_event.CopyRef()};
  tracked_device_buffer->AddUsageEvents(absl::MakeSpan(usage_events));

  // ASSERT instead of `.value()` so an error fails the test rather than
  // crashing the test binary.
  TF_ASSERT_OK_AND_ASSIGN(auto memory_space, device->default_memory_space());
  auto buffer = std::make_unique<TfrtGpuBuffer>(
      on_device_shape, std::move(tracked_device_buffer), tfrt_gpu_client_,
      device, memory_space);

  // Delete the buffer. The underlying device memory should not be freed until
  // both the definition event and the usage event are triggered.
  buffer->Delete();
  EXPECT_TRUE(buffer->IsDeleted());
  absl::SleepFor(absl::Milliseconds(50));
  EXPECT_FALSE(destructed);

  definition_event.SetStateConcrete();
  EXPECT_FALSE(destructed);

  // TODO(b/382117736): should also wait for donation event.

  usage_event.SetStateConcrete();
  EXPECT_TRUE(destructed);
}

// For a static shape, a buffer created from host data reports the requested
// shape both as its on-device shape and as its logical on-device shape.
TEST_F(TfrtGpuBufferTest, IsDeviceShapeWhenStaticShape) {
  // Six int32 values (24 bytes) are enough for a 3x2 array of any element
  // type exercised below (at most 4 bytes per element).
  std::vector<int32_t> host_values{1, 2, 3, 4, 5, 6};
  for (const PrimitiveType element_type : {F32, F16, S8, BF16}) {
    const Shape expected_shape = ShapeUtil::MakeShape(element_type, {3, 2});
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<PjRtBuffer> device_buffer,
        client_->BufferFromHostBuffer(
            host_values.data(), expected_shape.element_type(),
            expected_shape.dimensions(),
            /*byte_strides=*/std::nullopt,
            PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
            /*on_done_with_host_buffer=*/nullptr,
            client_->memory_spaces()[0], /*device_layout=*/nullptr));
    EXPECT_EQ(device_buffer->on_device_shape(), expected_shape);
    EXPECT_EQ(*device_buffer->logical_on_device_shape(), expected_shape);
  }
}

// Copying an error ("poisoned") buffer between any pair of memory spaces must
// produce a buffer whose ready future resolves to the original error.
TEST_F(TfrtGpuBufferTest, CopyPoisonedBuffer) {
  const Shape shape = ShapeUtil::MakeShape(F32, {8});
  const char* error_message = "injected error";

  // Exercise every (source, destination) memory-space combination.
  for (auto* const src : client_->memory_spaces()) {
    for (auto* const dst : client_->memory_spaces()) {
      TF_ASSERT_OK_AND_ASSIGN(
          auto poisoned,
          client_->CreateErrorBuffer(absl::InternalError(error_message), shape,
                                     src));
      TF_ASSERT_OK_AND_ASSIGN(auto copied, poisoned->CopyToMemorySpace(dst));

      EXPECT_THAT(copied->GetReadyFuture().Await(),
                  testing::status::StatusIs(absl::StatusCode::kInternal,
                                            error_message));
    }
  }
}

// TODO(b/382117736): Add a test for the logical on-device shape of a buffer
// with a dynamic shape once TfrtGpuClient::Execute() is implemented.

}  // namespace
}  // namespace xla
