/* Copyright 2017 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "xla/debug_options_flags.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_print_options.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/pass/hlo_pass_interface.h"
#include "xla/hlo/testlib/filecheck.h"
#include "xla/hlo/testlib/verified_hlo_module.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/service/computation_placer.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_verifier.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
#include "xla/util.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"

namespace xla {

HloHardwareIndependentTestBase::HloHardwareIndependentTestBase(
    bool verifier_layout_sensitive, bool allow_mixed_precision_in_hlo_verifier,
    HloPredicate instruction_can_change_layout_func,
    bool verify_no_collective_deadlocks)
    : verifier_layout_sensitive_(verifier_layout_sensitive),
      allow_mixed_precision_in_hlo_verifier_(
          allow_mixed_precision_in_hlo_verifier),
      instruction_can_change_layout_func_(instruction_can_change_layout_func) {
  // Shapes are sized by their full byte size for the fixture's own verifier.
  auto byte_size_of = [](const Shape& shape) {
    return ShapeUtil::ByteSizeOf(shape);
  };
  // Build the verifier from the settings just stored in the members so the
  // fixture's verifier and the modules it creates use the same flags. The
  // predicate member already holds its own copy, so the parameter can be moved.
  hlo_verifier_ = std::make_unique<HloVerifier>(
      /*layout_sensitive=*/verifier_layout_sensitive_,
      /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier_,
      std::move(instruction_can_change_layout_func), byte_size_of,
      verify_no_collective_deadlocks);
}

HloHardwareIndependentTestBase::HloHardwareIndependentTestBase(
    HloVerifierOpts&& verifier_opts) {
  // Only the fixture's verifier is configured from `verifier_opts`.
  // NOTE(review): verifier_layout_sensitive_,
  // allow_mixed_precision_in_hlo_verifier_ and
  // instruction_can_change_layout_func_ are not set here, yet
  // CreateNewVerifiedModule / ParseAndReturnVerifiedModule read them; confirm
  // their in-class defaults are consistent with `verifier_opts`.
  auto verifier = std::make_unique<HloVerifier>(std::move(verifier_opts));
  hlo_verifier_ = std::move(verifier);
}

std::unique_ptr<HloModule>
HloHardwareIndependentTestBase::CreateNewUnverifiedModule(
    const std::string& name) const {
  // A plain HloModule (not a VerifiedHloModule) built with the test config.
  auto unverified = std::make_unique<HloModule>(name, GetModuleConfigForTest());
  return unverified;
}

std::unique_ptr<VerifiedHloModule>
HloHardwareIndependentTestBase::CreateNewVerifiedModule(
    const std::string& name, int64_t replica_count) const {
  return std::make_unique<VerifiedHloModule>(
      name, GetModuleConfigForTest(replica_count), verifier_layout_sensitive_,
      allow_mixed_precision_in_hlo_verifier_, ShapeUtil::ByteSizeOfElements,
      instruction_can_change_layout_func_);
}

/* static */ DeviceAssignment
HloHardwareIndependentTestBase::GetDefaultDeviceAssignment(
    int64_t replica_count, int64_t num_partitions) {
  // A replica_count x num_partitions assignment whose entries are filled with
  // an iota sequence starting at device 0.
  DeviceAssignment assignment(replica_count, num_partitions);
  assignment.FillIota(0);
  return assignment;
}

absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule(
    absl::string_view hlo_text, int64_t replica_count, int64_t num_partitions,
    std::optional<DeviceAssignment> device_assignment) const {
  HloModuleConfig module_config =
      GetModuleConfigForTest(replica_count, num_partitions);
  // A caller-supplied assignment is installed as the static one; without it,
  // the config-taking overload installs a default when the config has none.
  if (device_assignment) {
    module_config.set_static_device_assignment(*device_assignment);
  }
  return ParseAndReturnVerifiedModule(hlo_text, module_config);
}

// Copies the layouts of the entry computation's parameter instructions and
// root instruction into the module's entry computation layout, so the declared
// entry layout matches the program as written.
//
// Returns OK without touching anything if `module` has no entry computation;
// otherwise returns the first error reported by CopyLayoutFromShape.
absl::Status HloHardwareIndependentTestBase::
    UpdateEntryComputationLayoutToMatchProgramLayout(HloModule* module) {
  // A module has at most one entry computation, so look it up directly
  // instead of scanning every computation in the module.
  if (!module->has_entry_computation()) {
    return absl::OkStatus();
  }
  const HloComputation* entry = module->entry_computation();
  auto* entry_layout = module->mutable_entry_computation_layout();

  // Parameter i of the entry layout takes the layout of parameter instruction i.
  for (int64_t i = 0; i < entry->num_parameters(); ++i) {
    TF_RETURN_IF_ERROR(
        entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            entry->parameter_instruction(i)->shape()));
  }

  // The result layout takes the layout of the root instruction.
  return entry_layout->mutable_result_layout()->CopyLayoutFromShape(
      entry->root_instruction()->shape());
}

absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule(
    absl::string_view hlo_text, const HloModuleConfig& config,
    const HloParserOptions& parser_options) const {
  // Delegate to the most general overload, sizing shapes by element bytes.
  std::function<int64_t(const xla::Shape&)> element_byte_size =
      ShapeUtil::ByteSizeOfElements;
  return ParseAndReturnVerifiedModule(hlo_text, config, parser_options,
                                      std::move(element_byte_size));
}

absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule(
    absl::string_view hlo_text, const HloModuleConfig& config,
    const HloParserOptions& parser_options,
    std::function<int64_t(const xla::Shape&)> shape_size_fn) const {
  // Work on a copy so the caller's config is never modified.
  HloModuleConfig effective_config = config;
  if (!config.has_static_device_assignment()) {
    // Refresh the cached default assignment and copy it into the config while
    // still holding device_assignment_mu_, which guards the shared member.
    absl::MutexLock lock(device_assignment_mu_);
    default_device_assignment_ =
        std::make_unique<DeviceAssignment>(GetDefaultDeviceAssignment(
            config.replica_count(), config.num_partitions()));
    effective_config.set_static_device_assignment(*default_device_assignment_);
  }

  // The module is named after the running test.
  auto verified_module = std::make_unique<VerifiedHloModule>(
      TestName(), effective_config, verifier_layout_sensitive_,
      allow_mixed_precision_in_hlo_verifier_, std::move(shape_size_fn),
      instruction_can_change_layout_func_);
  TF_RETURN_IF_ERROR(verified_module->ParseHloStringAndVerifyModule(
      hlo_text, parser_options));
  return verified_module;
}

/* static */
absl::StatusOr<bool> HloHardwareIndependentTestBase::RunHloPass(
    HloPassInterface* hlo_pass, HloModule* module) {
  // Serialize the module before and after the pass so the pass's own claim
  // about whether it changed anything can be cross-checked.
  auto snapshot = [module] { return module->ToProto().ShortDebugString(); };
  const std::string before_run = snapshot();
  TF_ASSIGN_OR_RETURN(bool changed, hlo_pass->Run(module));
  const std::string after_run = snapshot();

  if (!changed) {
    EXPECT_EQ(after_run, before_run) << absl::StrFormat(
        "HLO pass %s claims to have not changed the module, but the module is "
        "different after the pass.",
        hlo_pass->name());
    return false;
  }

  EXPECT_NE(after_run, before_run) << absl::StrFormat(
      "HLO pass %s claims to have changed the module, but the module remains "
      "the same.",
      hlo_pass->name());
  return true;
}

/* static */
PrecisionConfig HloHardwareIndependentTestBase::DefaultPrecisionConfig(
    int operands) {
  // One PrecisionConfig::DEFAULT entry per operand.
  PrecisionConfig config;
  auto* operand_precisions = config.mutable_operand_precision();
  operand_precisions->Resize(operands, PrecisionConfig::DEFAULT);
  return config;
}

void HloHardwareIndependentTestBase::SetAotFastMathDebugOptions(
    DebugOptions* options) {
  // GPU: enable fast min/max.
  options->set_xla_gpu_enable_fast_min_max(true);

  // CPU: enable fast math and fast min/max ...
  options->set_xla_cpu_enable_fast_math(true);
  options->set_xla_cpu_enable_fast_min_max(true);
  // ... and turn off every fast-math "honor" flag.
  options->set_xla_cpu_fast_math_honor_nans(false);
  options->set_xla_cpu_fast_math_honor_infs(false);
  options->set_xla_cpu_fast_math_honor_functions(false);
  options->set_xla_cpu_fast_math_honor_division(false);
}

DebugOptions HloHardwareIndependentTestBase::GetDebugOptionsForTest() const {
  // Start from the flag-derived options and apply test-specific overrides.
  DebugOptions options = GetDebugOptionsFromFlags();
  // TODO(b/38354253): Change tests to use Parameters instead of Constants.
  options.add_xla_disable_hlo_passes("constant_folding");
  options.set_xla_hlo_evaluator_use_fast_path(true);
  options.set_xla_cpu_emitter_verification_level(1);
  return options;
}

// Parses `hlo` (using `config` when non-null), runs `hlo_pass` on it, and
// checks the outcome:
//   * the pass must report a change iff `expected` holds a FileCheck pattern;
//   * when it does change the module, the printed module (large constants
//     included) must match `expected`, with `additional_check_prefixes`
//     forwarded to FileCheck, and `after_pass_checks` (if set) is then invoked
//     on the rewritten module.
void HloHardwareIndependentTestBase::RunAndFilecheckHloRewrite(
    absl::string_view hlo, HloPassInterface&& hlo_pass,
    std::optional<absl::string_view> expected,
    std::function<void(HloModule*)> after_pass_checks,
    const HloModuleConfig* config,
    absl::Span<const absl::string_view> additional_check_prefixes) const {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          config ? ParseAndReturnVerifiedModule(hlo, *config)
                                 : ParseAndReturnVerifiedModule(hlo));
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&hlo_pass, module.get()));
  EXPECT_EQ(changed, expected.has_value()) << module->ToString();

  // EXPECT_EQ is non-fatal, so `changed` may be true even though no pattern
  // was supplied. Dereferencing an empty optional is undefined behavior, so
  // stop here; the mismatch has already been reported above.
  if (!changed || !expected.has_value()) {
    return;
  }

  TF_ASSERT_OK_AND_ASSIGN(
      bool filecheck_matches,
      RunFileCheck(
          module->ToString(HloPrintOptions().set_print_large_constants(true)),
          *expected, additional_check_prefixes));
  EXPECT_TRUE(filecheck_matches) << module->ToString();
  if (after_pass_checks) {
    after_pass_checks(module.get());
  }
}

void HloHardwareIndependentTestBase::RunAndFilecheckHloRewrite(
    absl::string_view hlo_with_checks, HloPassInterface&& hlo_pass,
    std::function<void(HloModule*)> after_pass_checks,
    const HloModuleConfig* config) const {
  // The input text doubles as the FileCheck pattern. Because a pattern is
  // always supplied, the pass is always expected to change the module.
  const std::optional<absl::string_view> pattern(hlo_with_checks);
  RunAndFilecheckHloRewrite(hlo_with_checks, std::move(hlo_pass), pattern,
                            std::move(after_pass_checks), config);
}

// Instantiates `hlo_template` by replacing every key of `params` with its
// value, parses and verifies the result, runs `hlo_pass` on it, and
// non-fatally expects the pass's "changed" result to equal `expect_change`.
//
// Returns the (possibly rewritten) module, or the first parse/verify/pass
// error.
absl::StatusOr<std::unique_ptr<HloModule>>
HloHardwareIndependentTestBase::RunAndCheckHloRewrite(
    absl::string_view hlo_template, HloPassInterface* hlo_pass,
    bool expect_change, FixedMapping params) const {
  std::string hlo_string = absl::StrReplaceAll(hlo_template, params);
  // Attach the instantiated HLO to any failure reported in this scope.
  SCOPED_TRACE("Input HLO: " + hlo_string);
  VLOG(7) << "Input HLO: " << hlo_string;
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      ParseAndReturnVerifiedModule(hlo_string));
  VLOG(7) << "Input HLO parsed. Running the pass: " << hlo_pass->name();
  TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(hlo_pass, module.get()));
  VLOG(7) << "Output HLO: "
          << module->ToString(HloPrintOptions::ShortParsable()
                                  .set_print_control_dependencies(true));
  EXPECT_EQ(changed, expect_change);
  return module;
}

std::vector<int> HloHardwareIndependentTestBase::CompareInputs(
    const HloModule& module_0, const HloModule& module_1) {
  const auto params_0 = module_0.entry_computation()->parameter_instructions();
  const auto params_1 = module_1.entry_computation()->parameter_instructions();
  std::vector<int> mismatches;
  int64_t min = std::min(params_0.size(), params_1.size());
  int64_t max = std::max(params_0.size(), params_1.size());
  for (int64_t i = 0; i < min; ++i) {
    const HloModuleConfig& module_config_0 = module_0.config();
    const Shape& param_shape_0 =
        (module_config_0.has_entry_computation_layout() &&
         module_config_0.entry_computation_layout()
             .parameter_layout(i)
             .shape()
             .is_static())
            ? module_config_0.entry_computation_layout()
                  .parameter_layout(i)
                  .shape()
            : params_0[i]->shape();

    const HloModuleConfig& module_config_1 = module_1.config();
    const Shape& param_shape_1 =
        (module_config_1.has_entry_computation_layout() &&
         module_config_1.entry_computation_layout()
             .parameter_layout(i)
             .shape()
             .is_static())
            ? module_config_1.entry_computation_layout()
                  .parameter_layout(i)
                  .shape()
            : params_1[i]->shape();

    if (!Shape::Equal().IgnoreTilesInLayout()(param_shape_0, param_shape_1)) {
      mismatches.push_back(i);
    }
  }
  for (int64_t i = min; i < max; i++) {
    mismatches.push_back(i);
  }
  return mismatches;
}

/* static */
HloComputation* HloHardwareIndependentTestBase::FindComputation(
    HloModule* module, absl::string_view name) {
  // Thin wrapper over hlo_query::FindComputation; presumably yields nullptr
  // when no computation is named `name` (confirm in hlo_query).
  HloComputation* const found = hlo_query::FindComputation(module, name);
  return found;
}

/* static */
HloInstruction* HloHardwareIndependentTestBase::FindInstruction(
    const HloModule* module, absl::string_view name) {
  // Search computations in the order module->computations() yields them and
  // stop at the first one containing an instruction named `name`.
  HloInstruction* match = nullptr;
  for (const HloComputation* comp : module->computations()) {
    match = hlo_query::FindInstruction(comp, name);
    if (match != nullptr) {
      break;
    }
  }
  return match;
}

/* static */
HloInstruction* HloHardwareIndependentTestBase::FindInstruction(
    const HloModule* module, HloOpcode opcode) {
  // Search computations in the order module->computations() yields them and
  // stop at the first one containing an instruction with `opcode`.
  HloInstruction* match = nullptr;
  for (const HloComputation* comp : module->computations()) {
    match = hlo_query::FindInstruction(comp, opcode);
    if (match != nullptr) {
      break;
    }
  }
  return match;
}

/* static */
std::vector<HloInstruction*> HloHardwareIndependentTestBase::FindInstructions(
    const HloModule* module, HloOpcode opcode) {
  // Collect every instruction with `opcode` across all computations,
  // preserving computation order and then instruction order within each.
  std::vector<HloInstruction*> matches;
  for (const HloComputation* comp : module->computations()) {
    for (HloInstruction* instr : comp->instructions()) {
      if (instr->opcode() == opcode) {
        matches.push_back(instr);
      }
    }
  }
  return matches;
}

}  // namespace xla
