/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mlir/IR/MLIRContext.h"
#include "xla/backends/autotuner/autotuner.h"
#include "xla/backends/autotuner/autotuner_cache_interface.h"
#include "xla/backends/autotuner/codegen_backend.h"
#include "xla/backends/autotuner/profiler.h"
#include "xla/backends/gpu/autotuner/factory.h"
#include "xla/backends/gpu/autotuner/gpu_profiler.h"
#include "xla/backends/gpu/autotuner/legacy_cache.h"
#include "xla/debug_options_flags.h"
#include "xla/hlo/analysis/symbolic_expr.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/service/compiler.h"
#include "xla/service/platform_util.h"
#include "xla/stream_executor/device_address_allocator.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform/platform_object_registry.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/stream_executor/stream_executor_memory_allocator.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/threadpool.h"
#include "xla/tsl/util/command_line_flags.h"
#include "xla/xla.pb.h"
#include "tsl/platform/cpu_info.h"
#include "tsl/platform/init_main.h"

namespace {

// Tool description placed in front of the generated per-flag help. main()
// combines it with tsl::Flags::Usage() and prints the result when flag
// parsing fails.
constexpr char kUsage[] = R"(
This tool autotunes an HLO module from a given HLO file and prints the
autotuned module to stdout.

Usage:

  bazel run autotuner_main -- --hlo_file=path/to/hlo_module \
    [--cache_dir=path/to/cache_dir] \
    [--autotune_cache_mode=READ|READ_WRITE]
)";

}  // namespace

namespace xla {
namespace gpu {
namespace {

// Loads the HLO text stored at `hlo_file` and parses it into an HloModule.
// The module is produced by ParseAndReturnUnverifiedModule, so no HLO
// verification is performed here. A read failure or a parse failure is
// returned as the error status.
absl::StatusOr<std::unique_ptr<HloModule>> GetModule(
    const std::string& hlo_file) {
  tsl::Env* const env = tsl::Env::Default();
  std::string module_text;
  TF_RETURN_IF_ERROR(tsl::ReadFileToString(env, hlo_file, &module_text));
  return ParseAndReturnUnverifiedModule(module_text);
}

// Maps the user-facing --autotune_cache_mode value onto the DebugOptions
// cache mode: "READ_WRITE" -> AUTOTUNE_CACHE_MODE_UPDATE and
// "READ" -> AUTOTUNE_CACHE_MODE_READ. ASCII case is ignored. Any other value
// yields an InvalidArgumentError.
absl::StatusOr<DebugOptions::AutotuneCacheMode> ParseCacheModeFlag(
    absl::string_view mode_str) {
  const std::string normalized = absl::AsciiStrToUpper(mode_str);
  if (normalized == "READ_WRITE") {
    return DebugOptions::AUTOTUNE_CACHE_MODE_UPDATE;
  }
  if (normalized == "READ") {
    return DebugOptions::AUTOTUNE_CACHE_MODE_READ;
  }
  return absl::InvalidArgumentError(
      absl::StrCat("Invalid autotune_cache_mode: ", mode_str));
}

// Autotunes `module` in place using device 0 of the canonical "gpu" platform.
//
// `cache_dir`: when non-empty, a LegacyCache rooted at this directory is
//   attached to the autotuner. When empty, no cache is used.
// `autotune_cache_mode_str`: "READ" or "READ_WRITE" (case-insensitive). It is
//   always validated but only matters when `cache_dir` is non-empty.
// `mlir_context`: forwarded to the codegen backend factory. It must stay
//   valid for the duration of this call.
//
// Returns an error if the mode string is invalid, no device is visible, any
// setup step fails, or autotuning itself fails.
absl::Status Autotune(HloModule& module, const std::string& cache_dir,
                      const std::string& autotune_cache_mode_str,
                      mlir::MLIRContext* mlir_context) {
  // Validate cheap user input first, so a flag typo is reported before the
  // tool pays for platform, compiler and backend initialization.
  TF_ASSIGN_OR_RETURN(DebugOptions::AutotuneCacheMode cache_mode,
                      ParseCacheModeFlag(autotune_cache_mode_str));

  TF_ASSIGN_OR_RETURN(std::string platform_name,
                      PlatformUtil::CanonicalPlatformName("gpu"));
  TF_ASSIGN_OR_RETURN(se::Platform * platform,
                      se::PlatformManager::PlatformWithName(
                          absl::AsciiStrToUpper(platform_name)));
  if (platform->VisibleDeviceCount() == 0) {
    return absl::InternalError("No devices found");
  }

  TF_ASSIGN_OR_RETURN(std::unique_ptr<Compiler> compiler,
                      xla::Compiler::GetForPlatform(platform->id()));
  // Propagate a failure to obtain device 0 as a status instead of crashing
  // through StatusOr::value().
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor,
                      platform->ExecutorForDevice(0));
  DebugOptions debug_options = GetDebugOptionsFromFlags();
  Compiler::GpuTargetConfig target_config(stream_executor);

  // The codegen-backend factory is looked up per platform in the global
  // registry.
  auto& registry = stream_executor::PlatformObjectRegistry::GetGlobalRegistry();
  TF_ASSIGN_OR_RETURN(const GetCodegenBackends::Type& get_codegen_backends,
                      registry.FindObject<GetCodegenBackends>(platform->id()));
  std::vector<std::unique_ptr<CodegenBackend>> backends =
      get_codegen_backends(stream_executor, &debug_options, compiler.get(),
                           &target_config, mlir_context);

  // `compiler`, `debug_options`, `target_config`, `allocator` and
  // `thread_pool` are handed out by pointer. They are all declared before
  // `autotuner`, so they are destroyed after it.
  std::unique_ptr<se::DeviceAddressAllocator> allocator =
      std::make_unique<stream_executor::StreamExecutorAddressAllocator>(
          stream_executor);
  auto profiler =
      GpuProfiler::Create(stream_executor, ProfileOptions(), allocator.get());
  if (profiler == nullptr) {
    return absl::InternalError("Failed to create profiler");
  }

  tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "autotuner",
                                      tsl::port::MaxParallelism());

  // A cache is attached only when a directory was given.
  std::unique_ptr<AutotunerCacheInterface> cache;
  if (!cache_dir.empty()) {
    cache = std::make_unique<LegacyCache>(cache_dir, cache_mode,
                                          target_config.device_description);
  }

  AutotuneConfig autotune_config;
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Autotuner> autotuner,
      Autotuner::Create(std::move(backends), std::move(profiler),
                        autotune_config, std::move(cache), &thread_pool));

  // TODO: b/407494793 - Expand the filter to include more instructions.
  // For now only custom-kind fusions and custom calls are autotuned.
  auto should_autotune = [](const HloInstruction& instruction) -> bool {
    // fusion_kind() is only queried once the opcode is known to be kFusion.
    const bool is_custom_fusion =
        instruction.opcode() == HloOpcode::kFusion &&
        instruction.fusion_kind() == HloInstruction::FusionKind::kCustom;
    return is_custom_fusion || instruction.opcode() == HloOpcode::kCustomCall;
  };

  return autotuner->Autotune(&module, should_autotune);
}

}  // namespace
}  // namespace gpu
}  // namespace xla

// Command-line entry point. It parses the flags, loads the module named by
// --hlo_file, autotunes it on the local GPU, and prints the resulting module
// text to stdout.
//
// Exits through LOG(QFATAL) (printing the usage text) when flag parsing
// fails or --hlo_file is missing. Exits through a CHECK failure when loading
// or autotuning the module fails.
int main(int argc, char* argv[]) {
  std::string hlo_file;
  std::string cache_dir;
  std::string autotune_cache_mode = "READ_WRITE";
  std::vector<tsl::Flag> flag_list = {
      tsl::Flag("hlo_file", &hlo_file, "Path to the HLO file to autotune."),
      tsl::Flag("cache_dir", &cache_dir,
                "Directory to store/load the autotune cache."),
      tsl::Flag("autotune_cache_mode", &autotune_cache_mode,
                "Autotune cache mode: READ or READ_WRITE.")};

  const std::string usage_string =
      absl::StrCat(kUsage, "\n\n", tsl::Flags::Usage(argv[0], flag_list));
  const bool parse_ok = tsl::Flags::Parse(&argc, argv, flag_list);
  if (!parse_ok) {
    LOG(QFATAL) << usage_string;
  }
  // --hlo_file is mandatory. Without this check the tool would only fail
  // later, when it tries to read an empty path.
  if (hlo_file.empty()) {
    LOG(QFATAL) << "--hlo_file is required.\n" << usage_string;
  }
  tsl::port::InitMain(usage_string.c_str(), &argc, &argv);

  auto module = xla::gpu::GetModule(hlo_file);
  CHECK_OK(module.status());

  mlir::MLIRContext mlir_context;
  xla::RegisterSymbolicExprStorage(&mlir_context);
  CHECK_OK(xla::gpu::Autotune(*module.value(), cache_dir, autotune_cache_mode,
                              &mlir_context));

  std::cout << module.value()->ToString() << std::endl;
  return 0;
}
