/* Copyright 2023 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include <gtest/gtest.h>
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/ascii.h"
#include "absl/types/span.h"
#include "xla/service/platform_util.h"
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/device_address.h"
#include "xla/stream_executor/gpu/gpu_test_kernels.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/kernel_spec.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/stream_executor/semantic_version.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/trace_command_buffer_factory.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
#include "xla/tsl/platform/test_benchmark.h"

namespace stream_executor::gpu {

// Resolves the platform registered under the upper-cased canonical name that
// XLA's PlatformUtil reports for "gpu". Both lookups go through
// StatusOr::value(), so any failure here is fatal for the test binary.
static Platform* GpuPlatform() {
  const auto canonical_name =
      xla::PlatformUtil::CanonicalPlatformName("gpu").value();
  return PlatformManager::PlatformWithName(
             absl::AsciiStrToUpper(canonical_name))
      .value();
}

static constexpr auto nested = CommandBuffer::Mode::kNested;    // NOLINT
static constexpr auto primary = CommandBuffer::Mode::kPrimary;  // NOLINT

// Some of the tests rely on CUDA 12.3+ features.
//
// Returns true only when `executor` belongs to the CUDA platform and both the
// runtime version and the driver version reported by its device description
// are at least 12.3.0.
static bool IsAtLeastCuda12300(
    const stream_executor::StreamExecutor* executor) {
  const bool is_cuda = executor->GetPlatform()->id() == cuda::kCudaPlatformId;
  if (!is_cuda) {
    return false;
  }

  const auto& description = executor->GetDeviceDescription();
  const SemanticVersion kMinVersion{12, 3, 0};
  // Equivalent to `min(runtime, driver) >= 12.3.0`: neither may be older.
  return !(description.runtime_version() < kMinVersion) &&
         !(description.driver_version() < kMinVersion);
}

// Adapts the result of a single Create* call into the
// `std::vector<const Command*>` shape that the CommandBuffer::CreateCommands
// builders in this file return. An error status is propagated unchanged.
//
// `static` gives this file-local helper internal linkage, matching the other
// helpers here and avoiding ODR clashes with same-named helpers in other test
// files that share the `stream_executor::gpu` namespace.
static absl::StatusOr<std::vector<const CommandBuffer::Command*>> Wrap(
    absl::StatusOr<const CommandBuffer::Command*> command) {
  TF_RETURN_IF_ERROR(command.status());
  return std::vector<const CommandBuffer::Command*>{*command};
}

// Records a single `add` kernel launch (c = a + b) into a primary command
// buffer and submits it. It then updates that recorded launch command
// (UpdateLaunch) to write into a fresh buffer `d` and submits again.
TEST(GpuCommandBufferTest, LaunchSingleKernel) {
  Platform* platform = GpuPlatform();
  StreamExecutor* executor = platform->ExecutorForDevice(0).value();

  TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
  TF_ASSERT_OK_AND_ASSIGN(auto add, LoadAddI32TestKernel(executor));

  int64_t length = 4;
  int64_t byte_length = sizeof(int32_t) * length;

  // Prepare arguments: a=1, b=2, c=0
  DeviceAddress<int32_t> a = executor->AllocateArray<int32_t>(length, 0);
  DeviceAddress<int32_t> b = executor->AllocateArray<int32_t>(length, 0);
  DeviceAddress<int32_t> c = executor->AllocateArray<int32_t>(length, 0);

  TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length));
  TF_ASSERT_OK(stream->Memset32(&b, 2, byte_length));
  TF_ASSERT_OK(stream->MemZero(&c, byte_length));

  // Create a command buffer with a single kernel launch.
  TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer,
                          executor->CreateCommandBuffer(primary));
  // Keep the returned command handle: it is needed for UpdateLaunch below.
  TF_ASSERT_OK_AND_ASSIGN(
      auto* launch,
      cmd_buffer->CreateLaunch(add, ThreadDim(), BlockDim(4), {}, a, b, c));
  TF_ASSERT_OK(cmd_buffer->Finalize());

  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));

  // Copy `c` data back to host.
  std::vector<int32_t> dst(4, 42);
  TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length));

  std::vector<int32_t> expected = {3, 3, 3, 3};
  ASSERT_EQ(dst, expected);

  // Prepare argument for graph update: d = 0
  DeviceAddress<int32_t> d = executor->AllocateArray<int32_t>(length, 0);
  TF_ASSERT_OK(stream->MemZero(&d, byte_length));

  // Update command buffer to write into `d` buffer.
  // Update() -> UpdateLaunch() -> Finalize() re-targets the existing command
  // rather than recording a new one.
  TF_ASSERT_OK(cmd_buffer->Update());
  TF_ASSERT_OK(
      cmd_buffer->UpdateLaunch(launch, add, ThreadDim(), BlockDim(4), a, b, d));
  TF_ASSERT_OK(cmd_buffer->Finalize());

  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));

  // Copy `d` data back to host.
  // Refill with a sentinel so stale values from the first read cannot pass.
  std::fill(dst.begin(), dst.end(), 42);
  TF_ASSERT_OK(stream->Memcpy(dst.data(), d, byte_length));
  ASSERT_EQ(dst, expected);
}

// Builds a primary command buffer by tracing one launch of the
// pointer-array variant of the add kernel (c = a + b), then submits it and
// checks the result. Only exercised when IsAtLeastCuda12300() holds.
TEST(GpuCommandBufferTest, TraceSingleKernel) {
  StreamExecutor* executor = GpuPlatform()->ExecutorForDevice(0).value();

  if (!IsAtLeastCuda12300(executor)) {
    GTEST_SKIP() << "Command buffer tracing is not supported";
  }

  TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
  TF_ASSERT_OK_AND_ASSIGN(auto add, LoadAddI32Ptrs3TestKernel(executor));

  constexpr int64_t kLength = 4;
  const int64_t byte_length = sizeof(int32_t) * kLength;

  // Inputs a=1, b=2 and output c=0.
  DeviceAddress<int32_t> a = executor->AllocateArray<int32_t>(kLength, 0);
  DeviceAddress<int32_t> b = executor->AllocateArray<int32_t>(kLength, 0);
  DeviceAddress<int32_t> c = executor->AllocateArray<int32_t>(kLength, 0);

  TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length));
  TF_ASSERT_OK(stream->Memset32(&b, 2, byte_length));
  TF_ASSERT_OK(stream->MemZero(&c, byte_length));

  // Pass the buffers as an array of device memory bases to exercise argument
  // packing.
  KernelArgsDeviceAddressArray args({a, b, c}, 0);

  // The launch that gets recorded while the factory traces the given stream.
  auto record_add = [&](Stream* traced_stream) {
    return add->Launch(ThreadDim(), BlockDim(4), traced_stream, args);
  };
  TF_ASSERT_OK_AND_ASSIGN(
      auto cmd_buffer,
      TraceCommandBufferFactory::Create(executor, record_add, primary));

  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));

  // Read `c` back and compare.
  std::vector<int32_t> dst(kLength, 42);
  TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length));

  const std::vector<int32_t> expected = {3, 3, 3, 3};
  ASSERT_EQ(dst, expected);
}

// Records an `add` launch (c = a + b) into a nested command buffer and embeds
// it into a primary command buffer as a child command. It then replaces the
// child with a new nested buffer that writes into `d` (UpdateChildCommand)
// and re-submits.
TEST(GpuCommandBufferTest, LaunchNestedCommandBuffer) {
  Platform* platform = GpuPlatform();
  StreamExecutor* executor = platform->ExecutorForDevice(0).value();

  TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());

  TF_ASSERT_OK_AND_ASSIGN(auto add, LoadAddI32TestKernel(executor));

  int64_t length = 4;
  int64_t byte_length = sizeof(int32_t) * length;

  // Prepare arguments: a=1, b=2, c=0
  DeviceAddress<int32_t> a = executor->AllocateArray<int32_t>(length, 0);
  DeviceAddress<int32_t> b = executor->AllocateArray<int32_t>(length, 0);
  DeviceAddress<int32_t> c = executor->AllocateArray<int32_t>(length, 0);

  TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length));
  TF_ASSERT_OK(stream->Memset32(&b, 2, byte_length));
  TF_ASSERT_OK(stream->MemZero(&c, byte_length));

  // Create a primary command buffer whose only command is a child command
  // wrapping a nested command buffer with a single kernel launch.
  TF_ASSERT_OK_AND_ASSIGN(auto primary_cmd,
                          executor->CreateCommandBuffer(primary));
  TF_ASSERT_OK_AND_ASSIGN(auto nested_cmd,
                          executor->CreateCommandBuffer(nested));
  TF_ASSERT_OK(
      nested_cmd->CreateLaunch(add, ThreadDim(), BlockDim(4), {}, a, b, c));
  // Keep the child command handle: it is needed for UpdateChildCommand below.
  TF_ASSERT_OK_AND_ASSIGN(auto* nested_command,
                          primary_cmd->CreateChildCommand(*nested_cmd, {}));
  TF_ASSERT_OK(primary_cmd->Finalize());

  TF_ASSERT_OK(primary_cmd->Submit(stream.get()));

  // Copy `c` data back to host.
  std::vector<int32_t> dst(4, 42);
  TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length));

  std::vector<int32_t> expected = {3, 3, 3, 3};
  ASSERT_EQ(dst, expected);

  // Prepare argument for graph update: d = 0
  DeviceAddress<int32_t> d = executor->AllocateArray<int32_t>(length, 0);
  TF_ASSERT_OK(stream->MemZero(&d, byte_length));

  // Update command buffer to write into `d` buffer by creating a new nested
  // command buffer.
  nested_cmd = executor->CreateCommandBuffer(nested).value();
  TF_ASSERT_OK(
      nested_cmd->CreateLaunch(add, ThreadDim(), BlockDim(4), {}, a, b, d));
  TF_ASSERT_OK(primary_cmd->Update());
  TF_ASSERT_OK(primary_cmd->UpdateChildCommand(nested_command, *nested_cmd));
  TF_ASSERT_OK(primary_cmd->Finalize());

  TF_ASSERT_OK(primary_cmd->Submit(stream.get()));

  // Copy `d` data back to host.
  std::fill(dst.begin(), dst.end(), 42);
  TF_ASSERT_OK(stream->Memcpy(dst.data(), d, byte_length));
  ASSERT_EQ(dst, expected);
}

// Records a device-to-device memcpy (a -> b) into a primary command buffer,
// submits it and checks the destination `b`. It then updates the same command
// to copy in the opposite direction (b -> a) and checks `a` after clearing it.
TEST(GpuCommandBufferTest, MemcpyDeviceToDevice) {
  Platform* platform = GpuPlatform();
  StreamExecutor* executor = platform->ExecutorForDevice(0).value();

  TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());

  int64_t length = 4;
  int64_t byte_length = sizeof(int32_t) * length;

  // Prepare arguments: a=42, b=uninitialized
  DeviceAddress<int32_t> a = executor->AllocateArray<int32_t>(length, 0);
  DeviceAddress<int32_t> b = executor->AllocateArray<int32_t>(length, 0);

  TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length));

  // Create a command buffer with a single a to b memcpy command.
  TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer,
                          executor->CreateCommandBuffer(primary));
  // Named `memcpy_cmd` (not `memcpy`) to avoid shadowing ::memcpy.
  TF_ASSERT_OK_AND_ASSIGN(auto* memcpy_cmd,
                          cmd_buffer->CreateMemcpyD2D(&b, a, byte_length, {}));
  TF_ASSERT_OK(cmd_buffer->Finalize());

  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));

  // Copy `b` (the memcpy destination) back to host. Reading `a` here would
  // always see 42 and would not verify the copy at all.
  std::vector<int32_t> dst(4, 0);
  TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length));

  std::vector<int32_t> expected = {42, 42, 42, 42};
  ASSERT_EQ(dst, expected);

  // Update command buffer to swap the memcpy direction.
  TF_ASSERT_OK(cmd_buffer->Update());
  TF_ASSERT_OK(cmd_buffer->UpdateMemcpyD2D(memcpy_cmd, &a, b, byte_length));
  TF_ASSERT_OK(cmd_buffer->Finalize());

  // Clear destination to test that command buffer actually copied memory.
  TF_ASSERT_OK(stream->Memset32(&a, 0, byte_length));

  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));

  // Copy `a` data back to host.
  std::fill(dst.begin(), dst.end(), 0);
  TF_ASSERT_OK(stream->Memcpy(dst.data(), a, byte_length));
  ASSERT_EQ(dst, expected);
}

// Records a 32-bit memset of `a` (pattern 42) into a primary command buffer
// and checks the result. It then updates the same memset command to use
// pattern 43 and checks again.
TEST(GpuCommandBufferTest, Memset) {
  Platform* platform = GpuPlatform();
  StreamExecutor* executor = platform->ExecutorForDevice(0).value();

  TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());

  int64_t length = 4;
  int64_t byte_length = sizeof(int32_t) * length;

  DeviceAddress<int32_t> a = executor->AllocateArray<int32_t>(length, 0);

  // Create a command buffer with a single memset command.
  auto cmd_buffer = executor->CreateCommandBuffer(primary).value();

  // `length` here is an element count (4), not a byte count.
  TF_ASSERT_OK_AND_ASSIGN(
      const CommandBuffer::Command* memset,
      cmd_buffer->CreateMemset(&a, uint32_t{42}, length, {}));
  TF_ASSERT_OK(cmd_buffer->Finalize());

  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));

  // Copy `a` data back to host.
  std::vector<int32_t> dst(4, 0);
  TF_ASSERT_OK(stream->Memcpy(dst.data(), a, byte_length));

  std::vector<int32_t> expected = {42, 42, 42, 42};
  ASSERT_EQ(dst, expected);

  // Update command buffer to use a new bit pattern.
  TF_ASSERT_OK(cmd_buffer->Update());
  TF_ASSERT_OK(cmd_buffer->UpdateMemset(memset, &a, uint32_t{43}, length));
  TF_ASSERT_OK(cmd_buffer->Finalize());

  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));

  // Copy `a` data back to host.
  std::fill(dst.begin(), dst.end(), 0);
  TF_ASSERT_OK(stream->Memcpy(dst.data(), a, byte_length));

  expected = {43, 43, 43, 43};
  ASSERT_EQ(dst, expected);
}

// Conditional case command with two branches: branch 0 computes c = a + b and
// branch 1 is empty. After branch 0 has run, the test checks that selecting
// the empty branch, or an out-of-bounds index (-1 or 2), leaves `c` unchanged.
TEST(GpuCommandBufferTest, ConditionalCaseEmptyGraph) {
  Platform* platform = GpuPlatform();
  StreamExecutor* executor = platform->ExecutorForDevice(0).value();

  // See b/362769658.
  if (!IsAtLeastCuda12300(executor)) {
    GTEST_SKIP() << "CUDA graph conditionals are not supported";
  }

  TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
  TF_ASSERT_OK_AND_ASSIGN(auto add, LoadAddI32TestKernel(executor));

  int64_t length = 4;
  int64_t byte_length = sizeof(int32_t) * length;

  // Prepare arguments: a=2, b=3, c=0, index=0
  DeviceAddress<int32_t> index = executor->AllocateArray<int32_t>(1, 0);
  DeviceAddress<int32_t> a = executor->AllocateArray<int32_t>(length, 0);
  DeviceAddress<int32_t> b = executor->AllocateArray<int32_t>(length, 0);
  DeviceAddress<int32_t> c = executor->AllocateArray<int32_t>(length, 0);

  TF_ASSERT_OK(stream->Memset32(&index, 0, sizeof(int32_t)));
  TF_ASSERT_OK(stream->Memset32(&a, 2, byte_length));
  TF_ASSERT_OK(stream->Memset32(&b, 3, byte_length));
  TF_ASSERT_OK(stream->MemZero(&c, byte_length));

  // if (index == 0) c = a + b
  CommandBuffer::CreateCommands branch0 = [&](CommandBuffer* b0, auto deps) {
    return Wrap(b0->CreateLaunch(add, ThreadDim(), BlockDim(4), deps, a, b, c));
  };

  // if (index == 1) <empty graph>
  // Records no commands, so the provided dependencies are unused.
  CommandBuffer::CreateCommands branch1 = [&](CommandBuffer*, auto deps) {
    return std::vector<const CommandBuffer::Command*>{};
  };

  std::vector<CommandBuffer::CreateCommands> branches;
  branches.push_back(std::move(branch0));
  branches.push_back(std::move(branch1));

  // Create a command buffer with a single conditional operation.
  TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer,
                          executor->CreateCommandBuffer(primary));
  TF_ASSERT_OK(cmd_buffer->CreateCase(index, std::move(branches), {}));
  TF_ASSERT_OK(cmd_buffer->Finalize());

  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
  TF_ASSERT_OK(stream->BlockHostUntilDone());

  // Copy `c` data back to host.
  std::vector<int32_t> dst(4, 42);
  TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length));

  std::vector<int32_t> expected_add = {5, 5, 5, 5};
  ASSERT_EQ(dst, expected_add);

  // Set index to `1`
  TF_ASSERT_OK(stream->Memset32(&index, 1, sizeof(int32_t)));

  // Submit the same command buffer, but this time it should take the empty path
  // and do nothing.
  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
  TF_ASSERT_OK(stream->BlockHostUntilDone());

  TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length));
  ASSERT_EQ(dst, expected_add);

  // Set index to `-1` (out of bound index value).
  // `c` is expected to stay unchanged, i.e. branch 0 must not run.
  TF_ASSERT_OK(stream->Memset32(&index, -1, sizeof(int32_t)));

  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
  TF_ASSERT_OK(stream->BlockHostUntilDone());

  TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length));
  ASSERT_EQ(dst, expected_add);

  // Set index to `2` (out of bound index value).
  // Same expectation as for `-1`: `c` stays unchanged.
  TF_ASSERT_OK(stream->Memset32(&index, 2, sizeof(int32_t)));

  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
  TF_ASSERT_OK(stream->BlockHostUntilDone());

  TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length));
  ASSERT_EQ(dst, expected_add);
}

// Parameterized fixture for conditional-case tests; the test parameter is the
// number of branches in the case command.
class GpuCommandBufferCaseTest : public testing::TestWithParam<int> {
 protected:
  int GetNumCases() { return GetParam(); }

  // Maps a possibly out-of-bounds case index to the branch the test expects to
  // run. Indices in [0, GetNumCases()) map to themselves; anything else maps to
  // the last branch.
  int GetEffectiveIndex(int i) {
    const int num_cases = GetNumCases();
    const bool in_range = i >= 0 && i < num_cases;
    if (in_range) {
      return i;
    }
    return num_cases - 1;
  }
};

// For N = GetNumCases() branches, records one conditional case command whose
// branch i computes results[i] = values[i] * values[i] = i * i. The command is
// submitted once for every index in [-1, N]. After each run, the test checks
// that only the branch chosen by GetEffectiveIndex() wrote its result and that
// all other results are 0.
TEST_P(GpuCommandBufferCaseTest, ConditionalMultiCase) {
  Platform* platform = GpuPlatform();
  StreamExecutor* executor = platform->ExecutorForDevice(0).value();

  if (!IsAtLeastCuda12300(executor)) {
    GTEST_SKIP() << "CUDA graph conditionals are not supported";
  }

  TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
  TF_ASSERT_OK_AND_ASSIGN(auto mul, LoadMulI32TestKernel(executor));

  constexpr int64_t kLength = 1;
  int64_t byte_length = sizeof(int32_t) * kLength;

  // Prepare arguments: index=0
  DeviceAddress<int32_t> index = executor->AllocateArray<int32_t>(1, 0);
  TF_ASSERT_OK(stream->Memset32(&index, 0, sizeof(int32_t)));

  const int kNumCases = GetNumCases();
  std::vector<DeviceAddress<int32_t>> values;
  std::vector<DeviceAddress<int32_t>> results;
  std::vector<CommandBuffer::CreateCommands> branches;
  // Size all vectors up front: the branch lambdas below capture `values` and
  // `results` by reference and index into them, and the vectors are not
  // resized afterwards.
  values.resize(kNumCases);
  results.resize(kNumCases);
  branches.resize(kNumCases);
  for (int i = 0; i < kNumCases; ++i) {
    values[i] = executor->AllocateArray<int32_t>(kLength, 0);
    TF_ASSERT_OK(stream->Memset32(&values[i], i, byte_length));
    results[i] = executor->AllocateArray<int32_t>(kLength, 0);
    TF_ASSERT_OK(stream->Memset32(&results[i], 0, byte_length));
    // `i` is captured by value so each branch keeps its own index.
    branches[i] = [&, i](CommandBuffer* branch_cmd, auto dependencies) {
      // result = i * i;
      return Wrap(branch_cmd->CreateLaunch(mul, ThreadDim(), BlockDim(kLength),
                                           dependencies, values[i], values[i],
                                           results[i]));
    };
  }

  // Create a command buffer with a single conditional operation.
  TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer,
                          executor->CreateCommandBuffer(primary));
  TF_ASSERT_OK(cmd_buffer->CreateCase(index, std::move(branches), {}));
  TF_ASSERT_OK(cmd_buffer->Finalize());

  // We test the out of bounds cases as well ( i < 0, i >= kNumCases).
  for (int i = -1; i <= kNumCases; ++i) {
    // Set index.
    TF_ASSERT_OK(stream->Memset32(&index, i, sizeof(int32_t)));

    // Submit case.
    TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
    TF_ASSERT_OK(stream->BlockHostUntilDone());

    int effective_index = GetEffectiveIndex(i);

    // Check all results are 0 except case index submitted.
    for (int z = 0; z < kNumCases; ++z) {
      std::vector<int32_t> dst(kLength, 42);
      TF_ASSERT_OK(stream->Memcpy(dst.data(), results[z], byte_length));

      // Build expected result vector.
      std::vector<int32_t> expected;
      expected.resize(kLength);
      for (int p = 0; p < kLength; ++p) {
        if (effective_index == z) {
          expected[p] = effective_index * effective_index;
        } else {
          expected[p] = 0;
        }
      }

      ASSERT_EQ(dst, expected)
          << "For result " << z << " after running case " << i;
      // Reset so the next submission starts from all-zero results.
      TF_ASSERT_OK(stream->Memset32(&results[z], 0, byte_length));
    }
  }
}

// Runs ConditionalMultiCase with 1 through 31 branches (the upper bound of
// testing::Range is exclusive).
INSTANTIATE_TEST_SUITE_P(ConditionalMultipleCaseTest, GpuCommandBufferCaseTest,
                         testing::Range(1, 32),
                         testing::PrintToStringParamName());

// Conditional case command with two branches: branch 0 computes c = a + b and
// branch 1 computes c = a * b. The test re-submits the same command buffer
// after changing the device-side `index`. Out-of-bounds indices (-1 and 2) are
// expected to produce the branch 1 (multiply) result.
TEST(GpuCommandBufferTest, ConditionalCase) {
  Platform* platform = GpuPlatform();
  StreamExecutor* executor = platform->ExecutorForDevice(0).value();

  if (!IsAtLeastCuda12300(executor)) {
    GTEST_SKIP() << "CUDA graph conditionals are not supported";
  }

  TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
  TF_ASSERT_OK_AND_ASSIGN(auto add, LoadAddI32TestKernel(executor));
  TF_ASSERT_OK_AND_ASSIGN(auto mul, LoadMulI32TestKernel(executor));

  int64_t length = 4;
  int64_t byte_length = sizeof(int32_t) * length;

  // Prepare arguments: a=2, b=3, c=0, index=0
  DeviceAddress<int32_t> index = executor->AllocateArray<int32_t>(1, 0);
  DeviceAddress<int32_t> a = executor->AllocateArray<int32_t>(length, 0);
  DeviceAddress<int32_t> b = executor->AllocateArray<int32_t>(length, 0);
  DeviceAddress<int32_t> c = executor->AllocateArray<int32_t>(length, 0);

  TF_ASSERT_OK(stream->Memset32(&index, 0, sizeof(int32_t)));
  TF_ASSERT_OK(stream->Memset32(&a, 2, byte_length));
  TF_ASSERT_OK(stream->Memset32(&b, 3, byte_length));
  TF_ASSERT_OK(stream->MemZero(&c, byte_length));

  // if (index == 0) c = a + b
  CommandBuffer::CreateCommands branch0 = [&](CommandBuffer* b0, auto deps) {
    return Wrap(b0->CreateLaunch(add, ThreadDim(), BlockDim(4), deps, a, b, c));
  };

  // if (index == 1) c = a * b
  CommandBuffer::CreateCommands branch1 = [&](CommandBuffer* b1, auto deps) {
    return Wrap(b1->CreateLaunch(mul, ThreadDim(), BlockDim(4), deps, a, b, c));
  };

  std::vector<CommandBuffer::CreateCommands> branches;
  branches.push_back(std::move(branch0));
  branches.push_back(std::move(branch1));

  // Create a command buffer with a single conditional operation.
  auto cmd_buffer = executor->CreateCommandBuffer(primary).value();
  TF_ASSERT_OK(cmd_buffer->CreateCase(index, std::move(branches), {}));
  TF_ASSERT_OK(cmd_buffer->Finalize());

  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
  TF_ASSERT_OK(stream->BlockHostUntilDone());

  // Copy `c` data back to host.
  std::vector<int32_t> dst(4, 42);
  TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length));

  std::vector<int32_t> expected_add = {5, 5, 5, 5};
  ASSERT_EQ(dst, expected_add);

  // Set index to `1`
  TF_ASSERT_OK(stream->Memset32(&index, 1, sizeof(int32_t)));

  // Submit the same command buffer, but this time it should multiply inputs.
  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
  TF_ASSERT_OK(stream->BlockHostUntilDone());

  TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length));
  std::vector<int32_t> expected_mul = {6, 6, 6, 6};
  ASSERT_EQ(dst, expected_mul);

  // Set index to `-1` (out of bound index value).
  // Expected to match the last branch (multiply).
  TF_ASSERT_OK(stream->Memset32(&index, -1, sizeof(int32_t)));

  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
  TF_ASSERT_OK(stream->BlockHostUntilDone());

  TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length));
  ASSERT_EQ(dst, expected_mul);

  // Set index to `2` (out of bound index value).
  // Expected to match the last branch (multiply).
  TF_ASSERT_OK(stream->Memset32(&index, 2, sizeof(int32_t)));

  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
  TF_ASSERT_OK(stream->BlockHostUntilDone());

  TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length));
  ASSERT_EQ(dst, expected_mul);
}

// While-loop command whose condition increments `loop_counter` and sets
// `pred = loop_counter++ < num_iters`, and whose body computes b = a + b.
// With a=1, b=0 and num_iters=10, `b` is expected to end up as 10.
TEST(GpuCommandBufferTest, ConditionalWhile) {
  Platform* platform = GpuPlatform();
  StreamExecutor* executor = platform->ExecutorForDevice(0).value();

  if (!IsAtLeastCuda12300(executor)) {
    GTEST_SKIP() << "CUDA graph conditionals are not supported";
  }

  TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
  TF_ASSERT_OK_AND_ASSIGN(auto add, LoadAddI32TestKernel(executor));
  TF_ASSERT_OK_AND_ASSIGN(auto cmp_and_inc, LoadCmpAndIncTestKernel(executor));

  int64_t length = 4;
  int64_t byte_length = sizeof(int32_t) * length;

  // Prepare arguments: a=1, b=0, loop_counter=0, num_iters=10, pred=false.
  // Value of `pred` is not important, as it will be updated by `create_cond`
  // below.
  DeviceAddress<bool> pred = executor->AllocateArray<bool>(1, 0);
  DeviceAddress<int32_t> loop_counter = executor->AllocateArray<int32_t>(1, 0);
  DeviceAddress<int32_t> num_iters = executor->AllocateArray<int32_t>(1, 0);
  DeviceAddress<int32_t> a = executor->AllocateArray<int32_t>(length, 0);
  DeviceAddress<int32_t> b = executor->AllocateArray<int32_t>(length, 0);

  static constexpr bool kFalse = false;
  TF_ASSERT_OK(stream->Memcpy(&pred, &kFalse, 1));
  TF_ASSERT_OK(stream->Memset32(&loop_counter, 0, sizeof(int32_t)));
  TF_ASSERT_OK(stream->Memset32(&num_iters, 10, sizeof(int32_t)));
  TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length));
  TF_ASSERT_OK(stream->MemZero(&b, byte_length));

  // Both builders forward `deps`, as every other builder in this file does.
  // Dropping them (passing `{}`) would discard whatever ordering the while
  // implementation asks these commands to respect.

  // Loop cond: loop_counter++ < num_iters;
  CommandBuffer::CreateCommands create_cond = [&](CommandBuffer* cond_cmd,
                                                  auto deps) {
    return Wrap(cond_cmd->CreateLaunch(cmp_and_inc, ThreadDim(), BlockDim(),
                                       deps, loop_counter, pred, num_iters));
  };

  // Loop body: b = a + b
  CommandBuffer::CreateCommands create_body = [&](CommandBuffer* body_cmd,
                                                  auto deps) {
    return Wrap(body_cmd->CreateLaunch(add, ThreadDim(), BlockDim(length), deps,
                                       a, b, b));
  };

  // Create a command buffer with a single conditional operation.
  auto cmd_buffer = executor->CreateCommandBuffer(primary).value();
  TF_ASSERT_OK(cmd_buffer->CreateWhile(pred, std::move(create_cond),
                                       std::move(create_body), {}));
  TF_ASSERT_OK(cmd_buffer->Finalize());

  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));

  // Copy `b` data back to host.
  std::vector<int32_t> dst(4, 42);
  TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length));

  std::vector<int32_t> expected = {10, 10, 10, 10};
  ASSERT_EQ(dst, expected);
}

// TODO(b/339653343): Re-enable when not failing.
//
// While-loop command whose body is a child command wrapping a nested command
// buffer. That nested buffer contains a conditional case keyed on `pred_then`;
// both of its branches compute b = a + b. With a=1, b=0 and num_iters=10, `b`
// is expected to end up as 10.
TEST(GpuCommandBufferTest, DISABLED_WhileNestedConditional) {
  Platform* platform = GpuPlatform();
  StreamExecutor* executor = platform->ExecutorForDevice(0).value();

  if (!IsAtLeastCuda12300(executor)) {
    GTEST_SKIP() << "CUDA graph conditionals are not supported";
  }

  TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
  TF_ASSERT_OK_AND_ASSIGN(auto add, LoadAddI32TestKernel(executor));
  TF_ASSERT_OK_AND_ASSIGN(auto cmp_and_inc, LoadCmpAndIncTestKernel(executor));

  int64_t length = 4;
  int64_t byte_length = sizeof(int32_t) * length;

  // Prepare arguments: a=1, b=0, loop_counter=0, num_iters=10, pred=false,
  // pred_then=true.
  // Value of `pred` is not important, as it will be updated by `create_cond`
  // below.
  DeviceAddress<bool> pred = executor->AllocateArray<bool>(1, 0);
  DeviceAddress<bool> pred_then = executor->AllocateArray<bool>(1, 0);
  DeviceAddress<int32_t> loop_counter = executor->AllocateArray<int32_t>(1, 0);
  DeviceAddress<int32_t> num_iters = executor->AllocateArray<int32_t>(1, 0);
  DeviceAddress<int32_t> a = executor->AllocateArray<int32_t>(length, 0);
  DeviceAddress<int32_t> b = executor->AllocateArray<int32_t>(length, 0);

  static constexpr bool kFalse = false;
  static constexpr bool kTrue = true;
  TF_ASSERT_OK(stream->Memcpy(&pred, &kFalse, 1));
  TF_ASSERT_OK(stream->Memcpy(&pred_then, &kTrue, 1));
  TF_ASSERT_OK(stream->Memset32(&loop_counter, 0, sizeof(int32_t)));
  TF_ASSERT_OK(stream->Memset32(&num_iters, 10, sizeof(int32_t)));
  TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length));
  TF_ASSERT_OK(stream->MemZero(&b, byte_length));

  CommandBuffer::CreateCommands create_then =
      // Then body: b = a + b
      [&](CommandBuffer* then_cmd, auto deps) {
        return Wrap(then_cmd->CreateLaunch(add, ThreadDim(), BlockDim(length),
                                           deps, a, b, b));
      };

  CommandBuffer::CreateCommands create_else =
      // Else body: b = a + b
      [&](CommandBuffer* else_cmd, auto deps) {
        return Wrap(else_cmd->CreateLaunch(add, ThreadDim(), BlockDim(length),
                                           deps, a, b, b));
      };

  std::vector<CommandBuffer::CreateCommands> branches;
  branches.push_back(std::move(create_then));
  branches.push_back(std::move(create_else));

  auto nested_cmd = executor->CreateCommandBuffer(nested).value();
  // TODO(b/339653343): Adding this Case condition causes AddNestedCommandBuffer
  // to fail.
  TF_ASSERT_OK(nested_cmd->CreateCase(pred_then, std::move(branches), {}));

  // Loop cond: loop_counter++ < num_iters;
  CommandBuffer::CreateCommands create_cond = [&](CommandBuffer* cond_cmd,
                                                  auto deps) {
    return Wrap(cond_cmd->CreateLaunch(cmp_and_inc, ThreadDim(),
                                       BlockDim(length), deps, loop_counter,
                                       pred, num_iters));
  };

  // Loop body: run the nested command buffer (with its case command) as a
  // child command.
  CommandBuffer::CreateCommands create_body = [&](CommandBuffer* body_cmd,
                                                  auto deps) {
    return Wrap(body_cmd->CreateChildCommand(*nested_cmd, deps));
  };

  // Create a command buffer with a single conditional operation.
  auto cmd_buffer = executor->CreateCommandBuffer(primary).value();
  TF_ASSERT_OK(cmd_buffer->CreateWhile(pred, std::move(create_cond),
                                       std::move(create_body), {}));
  TF_ASSERT_OK(cmd_buffer->Finalize());

  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));

  // Copy `b` data back to host.
  std::vector<int32_t> dst(4, 42);
  TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length));

  std::vector<int32_t> expected = {10, 10, 10, 10};
  ASSERT_EQ(dst, expected);
}

// Minimal command buffer resource carrying one integer payload; used by the
// GetOrCreateResource test below.
struct TestResource : public CommandBuffer::Resource {
  TestResource() = default;
  explicit TestResource(int32_t initial_value) : value(initial_value) {}

  int32_t value{0};
};

// Checks that a command buffer starts without a TestResource, that
// GetOrCreateResource builds one via the supplied factory, and that a later
// GetOrNullResource returns that same instance.
TEST(GpuCommandBufferTest, GetOrCreateResource) {
  Platform* platform = GpuPlatform();
  StreamExecutor* executor = platform->ExecutorForDevice(0).value();

  TF_ASSERT_OK_AND_ASSIGN(auto command_buffer,
                          executor->CreateCommandBuffer(nested));

  // No resource has been attached yet.
  EXPECT_EQ(command_buffer->GetOrNullResource<TestResource>(), nullptr);

  TestResource* resource = command_buffer->GetOrCreateResource<TestResource>(
      [] { return std::make_unique<TestResource>(42); });
  // ASSERT (not EXPECT): `resource` is dereferenced on the next line, so a
  // null result must abort the test instead of crashing it.
  ASSERT_NE(resource, nullptr);
  EXPECT_EQ(resource->value, 42);

  // A subsequent lookup must return the same instance.
  EXPECT_EQ(command_buffer->GetOrNullResource<TestResource>(), resource);
}

//===----------------------------------------------------------------------===//
// Performance benchmarks below
//===----------------------------------------------------------------------===//

// Registers benchmark NAME for 8, 32, 128, 512 and 1024 (read by the benchmark
// as state.range(0)). Call sites supply the terminating semicolon, so the
// macro itself must not end with one (that would yield a stray `;;`).
#define BENCHMARK_SIZES(NAME) \
  BENCHMARK(NAME)->Arg(8)->Arg(32)->Arg(128)->Arg(512)->Arg(1024)

// In benchmarks we construct command buffers in nested mode when we
// do not want to measure graph executable instantiation overhead.
//
// Measures recording `state.range(0) - 1` add-kernel launches into a fresh
// nested command buffer and finalizing it.
static void BM_CreateCommandBuffer(benchmark::State& state) {
  StreamExecutor* executor = GpuPlatform()->ExecutorForDevice(0).value();
  TF_ASSERT_OK_AND_ASSIGN(auto add, LoadAddI32TestKernel(executor));

  DeviceAddress<int32_t> buffer = executor->AllocateArray<int32_t>(1, 0);
  const int64_t num_commands_plus_one = state.range(0);

  for (auto _ : state) {
    auto cmd_buffer = executor->CreateCommandBuffer(nested).value();
    for (int64_t i = 1; i < num_commands_plus_one; ++i) {
      CHECK_OK(cmd_buffer->CreateLaunch(add, ThreadDim(), BlockDim(4), {},
                                        buffer, buffer, buffer));
    }
    CHECK_OK(cmd_buffer->Finalize());
  }
}

BENCHMARK_SIZES(BM_CreateCommandBuffer);

// Measures building a nested command buffer by tracing `state.range(0) - 1`
// add-kernel launches issued on the stream handed to the tracing callback.
static void BM_TraceCommandBuffer(benchmark::State& state) {
  Platform* platform = GpuPlatform();
  StreamExecutor* executor = platform->ExecutorForDevice(0).value();

  TF_ASSERT_OK_AND_ASSIGN(auto add, LoadAddI32TestKernel(executor));

  DeviceAddress<int32_t> b = executor->AllocateArray<int32_t>(1, 0);

  // The tracing callback is loop-invariant, so build it once. It launches on
  // the stream passed in by TraceCommandBufferFactory, so no stream of our own
  // is needed (the previously created one was shadowed and never used).
  auto launch_kernels = [&](Stream* stream) {
    for (int i = 1; i < state.range(0); ++i) {
      CHECK_OK(add.Launch(ThreadDim(), BlockDim(4), stream, b, b, b));
    }
    return absl::OkStatus();
  };

  for (auto s : state) {
    CHECK_OK(
        TraceCommandBufferFactory::Create(executor, launch_kernels, nested));
  }
}

BENCHMARK_SIZES(BM_TraceCommandBuffer);

// Measures updating an already-finalized primary command buffer holding
// `state.range(0) - 1` add-kernel launches. Each iteration re-targets every
// recorded launch via UpdateLaunch, the same Update() -> UpdateLaunch() ->
// Finalize() sequence used by the LaunchSingleKernel test. Calling CreateLaunch
// here instead would record new commands rather than update the existing ones.
static void BM_UpdateCommandBuffer(benchmark::State& state) {
  Platform* platform = GpuPlatform();
  StreamExecutor* executor = platform->ExecutorForDevice(0).value();
  TF_ASSERT_OK_AND_ASSIGN(auto add, LoadAddI32TestKernel(executor));

  DeviceAddress<int32_t> b = executor->AllocateArray<int32_t>(1, 0);

  // Record the launches once and keep their handles for the update loop.
  auto cmd_buffer = executor->CreateCommandBuffer(primary).value();
  std::vector<const CommandBuffer::Command*> launches;
  for (int i = 1; i < state.range(0); ++i) {
    auto launch =
        cmd_buffer->CreateLaunch(add, ThreadDim(), BlockDim(4), {}, b, b, b);
    CHECK_OK(launch.status());
    launches.push_back(*launch);
  }
  CHECK_OK(cmd_buffer->Finalize());

  for (auto s : state) {
    CHECK_OK(cmd_buffer->Update());
    for (const CommandBuffer::Command* launch : launches) {
      CHECK_OK(cmd_buffer->UpdateLaunch(launch, add, ThreadDim(), BlockDim(4),
                                        b, b, b));
    }
    CHECK_OK(cmd_buffer->Finalize());
  }
}

BENCHMARK_SIZES(BM_UpdateCommandBuffer);

}  // namespace stream_executor::gpu
