// Copyright 2025 The OpenXLA Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "xla/hlo/tools/hlo_diff/hlo_diff_summary.h"

#include <sys/stat.h>
#include <sys/types.h>

#include <algorithm>
#include <cstdint>
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/functional/bind_front.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h"
#include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph_node.h"
#include "xla/hlo/tools/hlo_diff/graph/utils/hlo_gumgraph_dfs.h"
#include "xla/hlo/tools/hlo_diff/hlo_diff_result.h"
#include "xla/hlo/tools/hlo_diff/hlo_gumgraph_mappings.h"
#include "xla/hlo/tools/hlo_diff/proto/diff_result.pb.h"
#include "xla/hlo/tools/hlo_diff/utils/bidirectional_map.h"
#include "xla/hlo/tools/hlo_diff/utils/connected_components.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/platform/fingerprint.h"

namespace xla {
namespace hlo_diff {
namespace {

// Bidirectional instruction-to-instruction mapping used throughout this file.
// By the convention followed below, the first slot holds the left-module
// instruction and the second slot the right-module instruction (entries are
// inserted as (left, right) and queried with GetRight()/GetLeft()).
// NOTE(review): HloInstructionNodeMappingProps is the per-pair property type;
// its fields are declared elsewhere and are not read in this file.
using InstructionBimap =
    BidirectionalMap<const HloInstruction*, const HloInstruction*,
                     HloInstructionNodeMappingProps>;

// Builds the left<->right instruction bimap from a diff result.
//
// Every (left, right) pair listed in `diff_result.unchanged_instructions` is
// inserted first, followed by every pair in `diff_result.changed_instructions`.
// Unmatched instructions are not part of the returned map.
InstructionBimap ConstructInstructionBimap(const DiffResult& diff_result) {
  InstructionBimap bimap;
  // Both containers expose (left, right) pairs, so one helper serves both.
  const auto insert_all = [&bimap](const auto& instruction_pairs) {
    for (const auto& [lhs, rhs] : instruction_pairs) {
      bimap.Insert(lhs, rhs);
    }
  };
  insert_all(diff_result.unchanged_instructions);
  insert_all(diff_result.changed_instructions);
  return bimap;
}

// Looks up the counterpart of `instruction` on the opposite side of the diff.
//
// `side` names the side that `instruction` itself belongs to: a left-side
// instruction is resolved with mapping.GetRight(), a right-side instruction
// with mapping.GetLeft(). The result of that lookup is returned as-is; callers
// treat an empty optional as "not mapped". std::nullopt is also returned if
// `side` holds a value not handled by the switch.
std::optional<const HloInstruction*> FindMappedInstruction(
    const InstructionBimap& mapping, const HloInstruction* instruction,
    DiffSide side) {
  switch (side) {
    case DiffSide::kLeft:
      return mapping.GetRight(instruction);
    case DiffSide::kRight:
      return mapping.GetLeft(instruction);
  }

  // Only reachable for an out-of-range `side` value.
  return std::nullopt;
}

// Outcome of FindMainMatchedComputation() for a single computation.
struct MainMatchedComputationResult {
  // Computation on the opposite side that received the most mapped
  // instructions; stays nullptr when no instruction is mapped.
  const HloComputation* main_matched_computation{nullptr};
  // Number of instructions mapped into `main_matched_computation`.
  int max_matched_instruction_count{0};
  // Number of mapped instructions that landed in any other computation.
  int split_allegiance_instruction_count{0};
};

// Finds the computation on the opposite side that `computation` maps to most.
//
// Each instruction of `computation` that has a counterpart in `mapping` casts
// one vote for the parent computation of that counterpart. The computation
// whose vote count first exceeds the running maximum wins, so ties go to the
// candidate that reached the maximum first in instruction order. The result
// also reports how many mapped instructions voted for a different computation
// ("split allegiance"). `side` is the side `computation` belongs to.
MainMatchedComputationResult FindMainMatchedComputation(
    const HloComputation* computation, const InstructionBimap& mapping,
    DiffSide side) {
  absl::flat_hash_map<const HloComputation*, int> votes_per_computation;
  MainMatchedComputationResult result;
  int total_mapped = 0;
  for (const HloInstruction* instruction : computation->instructions()) {
    const std::optional<const HloInstruction*> counterpart =
        FindMappedInstruction(mapping, instruction, side);
    if (!counterpart.has_value()) {
      continue;  // Unmapped instructions do not vote.
    }
    ++total_mapped;
    const HloComputation* candidate = (*counterpart)->parent();
    const int votes = ++votes_per_computation[candidate];
    if (votes > result.max_matched_instruction_count) {
      result.max_matched_instruction_count = votes;
      result.main_matched_computation = candidate;
    }
  }
  result.split_allegiance_instruction_count =
      total_mapped - result.max_matched_instruction_count;
  return result;
}

// Result of fingerprinting the diff state of one computation.
//
// Members carry default initializers so that a default-constructed value is
// never read in an indeterminate state.
struct DiffFingerprint {
  // True when every instruction of the computation has diff type kUnchanged.
  // Defaults to false so that an unfilled value is never mistaken for an
  // unchanged computation.
  bool all_unchanged = false;
  // Fingerprint of the computation's root subgraph, combining opcodes and
  // per-instruction diff types.
  uint64_t diff_fingerprint = 0;
};

// Fingerprints the diff state of `computation`.
//
// Instructions are visited in post order. The fingerprint of an instruction
// combines its opcode, its diff type (kUnchanged when the instruction has no
// entry in `diff_codes`) and the fingerprints of all of its operands, in
// operand order. The fingerprint of the root instruction is returned, together
// with a flag telling whether every visited instruction was kUnchanged.
DiffFingerprint ComputationDiffFingerprint(
    const xla::HloComputation* computation,
    const absl::flat_hash_map<const HloInstruction*, DiffType>& diff_codes) {
  absl::flat_hash_map<const HloInstruction*, uint64_t> fingerprint_of;
  DiffFingerprint result;
  result.all_unchanged = true;
  for (const HloInstruction* instruction :
       computation->MakeInstructionPostOrder()) {
    uint64_t diff_code = DiffType::kUnchanged;
    const auto found = diff_codes.find(instruction);
    if (found != diff_codes.end()) {
      diff_code = found->second;
    }
    if (diff_code != DiffType::kUnchanged) {
      result.all_unchanged = false;
    }

    uint64_t fingerprint = tsl::FingerprintCat64(
        static_cast<uint64_t>(instruction->opcode()), diff_code);
    // The operand lookups rely on operands having been visited before their
    // user in this traversal.
    for (const HloInstruction* operand : instruction->operands()) {
      fingerprint =
          tsl::FingerprintCat64(fingerprint, fingerprint_of.at(operand));
    }
    // TODO(b/394201811): Make sure no fingerprint collision.
    fingerprint_of[instruction] = fingerprint;
  }
  result.diff_fingerprint = fingerprint_of.at(computation->root_instruction());
  return result;
}

// Partitions `computations` into the left and right halves of a
// ComputationGroup, using the side recorded in `computation_summaries`.
// Input order is preserved within each half. Computations that have no summary
// entry are dropped.
ComputationGroup SplitComputations(
    const std::vector<const HloComputation*>& computations,
    const ComputationSummaryMap& computation_summaries) {
  ComputationGroup group;
  for (const HloComputation* computation : computations) {
    const auto summary_it = computation_summaries.find(computation);
    if (summary_it == computation_summaries.end()) {
      continue;
    }
    auto& bucket = summary_it->second.side == DiffSide::kLeft
                       ? group.left_computations
                       : group.right_computations;
    bucket.push_back(computation);
  }
  return group;
}

// Groups computations into connected components and buckets the components by
// a combined diff fingerprint.
//
// An edge is added between every computation and its main matched computation.
// Computations without a main match (all of their instructions are unmatched)
// each form a single-element component. Components in which every computation
// is unchanged are skipped. For the remaining components, the member
// fingerprints are sorted and folded into one component fingerprint, which is
// the key of the returned map; the value holds each component split into its
// left and right computations.
absl::flat_hash_map<uint64_t, std::vector<ComputationGroup>>
FindConnectedComponents(const ComputationSummaryMap& computation_summary) {
  ConnectedComponentsFinder finder;
  std::vector<const HloComputation*> fully_unmatched;
  for (const auto& [computation, match_info] : computation_summary) {
    if (match_info.main_matched_computation == nullptr) {
      // No main match means no instruction of the computation is mapped.
      fully_unmatched.push_back(computation);
      continue;
    }
    finder.AddEdge(computation, match_info.main_matched_computation);
  }

  std::vector<std::vector<const HloComputation*>> components =
      finder.FindConnectedComponents();
  // Unmatched computations follow the real components as singleton groups.
  for (const HloComputation* computation : fully_unmatched) {
    components.push_back({computation});
  }

  const auto fingerprint_of = [&computation_summary](const HloComputation* c) {
    return computation_summary.at(c).diff_fingerprint;
  };
  const auto is_unchanged = [&computation_summary](const HloComputation* c) {
    return computation_summary.at(c).all_unchanged;
  };

  absl::flat_hash_map<uint64_t, std::vector<ComputationGroup>> result;
  for (std::vector<const HloComputation*>& component : components) {
    // A component without any change carries no diff pattern.
    if (std::all_of(component.begin(), component.end(), is_unchanged)) {
      continue;
    }
    // Sort by fingerprint so the combined value does not depend on the order
    // in which the members were discovered.
    std::sort(component.begin(), component.end(),
              [&](const HloComputation* a, const HloComputation* b) {
                return fingerprint_of(a) < fingerprint_of(b);
              });
    uint64_t component_fingerprint = 0;
    for (const HloComputation* computation : component) {
      component_fingerprint = tsl::FingerprintCat64(
          component_fingerprint, fingerprint_of(computation));
    }
    result[component_fingerprint].push_back(
        SplitComputations(component, computation_summary));
  }
  return result;
}

// Counts the changed and unmatched instructions of one computation group.
//
// - changed_instruction_count: instructions of the group's left computations
//   that appear as a key in `diff_result.changed_instructions`. That container
//   is keyed by the left instruction of each (left, right) pair, as
//   ConstructInstructionBimap() iterates it, so every changed pair is counted
//   exactly once, from the left side. Right-side instructions are deliberately
//   not looked up in it: such a lookup cannot match a pair's key, and would
//   double-count the pair if it ever did.
// - left/right_unmatched_instruction_count: instructions of the left/right
//   computations found in the corresponding unmatched-instruction set.
//
// NOTE(review): relies on DiffMetrics (declared elsewhere) starting with all
// counters at zero.
DiffMetrics GetDiffMetrics(const ComputationGroup& computation_group,
                           const DiffResult& diff_result) {
  DiffMetrics result;
  for (const HloComputation* computation :
       computation_group.left_computations) {
    for (const HloInstruction* instruction : computation->instructions()) {
      if (diff_result.changed_instructions.contains(instruction)) {
        ++result.changed_instruction_count;
      } else if (diff_result.left_module_unmatched_instructions.contains(
                     instruction)) {
        ++result.left_unmatched_instruction_count;
      }
    }
  }
  for (const HloComputation* computation :
       computation_group.right_computations) {
    for (const HloInstruction* instruction : computation->instructions()) {
      if (diff_result.right_module_unmatched_instructions.contains(
              instruction)) {
        ++result.right_unmatched_instruction_count;
      }
    }
  }
  return result;
}

std::vector<ComputationDiffPattern> FindComputationDiffPatterns(
    const ComputationSummaryMap& computation_summary,
    const DiffResult& diff_result) {
  std::vector<ComputationDiffPattern> result;
  absl::flat_hash_map<uint64_t, std::vector<ComputationGroup>>
      connected_components = FindConnectedComponents(computation_summary);
  for (const auto& [fingerprint, computation_groups] : connected_components) {
    ComputationDiffPattern diff_pattern;
    diff_pattern.fingerprint = fingerprint;
    diff_pattern.computation_groups = computation_groups;
    diff_pattern.diff_metrics =
        GetDiffMetrics(computation_groups[0], diff_result);
    result.push_back(std::move(diff_pattern));
  }
  return result;
}

// Builds a ComputationSummary for every computation of `module`.
//
// `side` is the diff side `module` belongs to, and `diff_codes` holds the
// per-instruction diff types for that same side. Each summary combines the
// main-matched-computation statistics with the computation's diff fingerprint.
ComputationSummaryMap SummarizeAllComputationsInGraph(
    const HloModule& module, const InstructionBimap& mapping,
    const absl::flat_hash_map<const HloInstruction*, DiffType>& diff_codes,
    DiffSide side) {
  ComputationSummaryMap summaries;
  for (const HloComputation* computation : module.computations()) {
    const MainMatchedComputationResult match =
        FindMainMatchedComputation(computation, mapping, side);
    const DiffFingerprint fingerprint =
        ComputationDiffFingerprint(computation, diff_codes);

    ComputationSummary entry;
    entry.side = side;
    entry.main_matched_computation = match.main_matched_computation;
    entry.max_matched_instruction_count = match.max_matched_instruction_count;
    entry.split_allegiance_instruction_count =
        match.split_allegiance_instruction_count;
    entry.diff_fingerprint = fingerprint.diff_fingerprint;
    entry.all_unchanged = fingerprint.all_unchanged;
    summaries.emplace(computation, std::move(entry));
  }
  return summaries;
}

// Summarizes every computation of both modules into a single map.
//
// The instruction bimap is built once from `diff_result` and shared by both
// sides. Left-module summaries are merged in first, then right-module ones.
ComputationSummaryMap GetComputationSummary(const HloModule& left_module,
                                            const HloModule& right_module,
                                            const DiffResult& diff_result) {
  const InstructionBimap mapping = ConstructInstructionBimap(diff_result);

  ComputationSummaryMap left_summaries = SummarizeAllComputationsInGraph(
      left_module, mapping, diff_result.left_diff_codes, DiffSide::kLeft);
  ComputationSummaryMap right_summaries = SummarizeAllComputationsInGraph(
      right_module, mapping, diff_result.right_diff_codes, DiffSide::kRight);

  ComputationSummaryMap combined;
  combined.merge(std::move(left_summaries));
  combined.merge(std::move(right_summaries));
  return combined;
}

// Computes, for every non-root node of both graphs, whether the subgraph below
// it is unchanged.
//
// Both graphs are walked in post order from their roots. A node's subgraph
// counts as unchanged only if the node's own diff type is kUnchanged (a node
// missing from the diff codes counts as changed) and none of its already
// summarized children has a changed subgraph. Nodes flagged `is_root` are not
// given a summary.
InstructionSummaryMap GetInstructionSummary(const HloGumgraph& left_graph,
                                            const HloGumgraph& right_graph,
                                            const DiffResult& diff_result) {
  InstructionSummaryMap summaries;
  const auto summarize_node = [&diff_result, &summaries](
                                  const DiffSide side,
                                  const HloInstructionNode& node) {
    if (node.is_root) {
      return true;
    }
    const auto& diff_codes = side == DiffSide::kLeft
                                 ? diff_result.left_diff_codes
                                 : diff_result.right_diff_codes;
    const auto code_it = diff_codes.find(node.instruction);
    const bool self_unchanged = code_it != diff_codes.end() &&
                                code_it->second == DiffType::kUnchanged;

    // Children without a summary entry do not count as changed.
    const auto child_subgraph_changed =
        [&summaries](const HloInstructionNode* child) {
          const auto child_it = summaries.find(child->instruction);
          return child_it != summaries.end() &&
                 !child_it->second.subgraph_unchanged;
        };

    InstructionSummary summary;
    summary.side = side;
    // Children are only inspected when the node itself is unchanged.
    summary.subgraph_unchanged =
        self_unchanged && std::none_of(node.children.begin(),
                                       node.children.end(),
                                       child_subgraph_changed);
    summaries.insert({node.instruction, std::move(summary)});
    // Always return true to the traversal.
    return true;
  };

  HloGumgraphDfs(left_graph.GetRoot(),
                 absl::bind_front(summarize_node, DiffSide::kLeft),
                 DfsTraversalOrder::kPostOrder, left_graph.GetNodeCount());
  HloGumgraphDfs(right_graph.GetRoot(),
                 absl::bind_front(summarize_node, DiffSide::kRight),
                 DfsTraversalOrder::kPostOrder, right_graph.GetNodeCount());
  return summaries;
}

// Logs the computation group.
void LogComputationGroup(const ComputationGroup& computation_group) {
  std::vector<std::string> computations_str(
      computation_group.left_computations.size() +
      computation_group.right_computations.size());
  for (int i = 0; i < computation_group.left_computations.size(); ++i) {
    computations_str[i] = absl::StrFormat(
        "L: %s", computation_group.left_computations[i]->name());
  }
  for (int i = 0; i < computation_group.right_computations.size(); ++i) {
    computations_str[computation_group.left_computations.size() + i] =
        absl::StrFormat("R: %s",
                        computation_group.right_computations[i]->name());
  }
  LOG(INFO) << absl::StrJoin(computations_str, ", ");
}

// Logs one diff pattern: a header line with the number of computation groups
// and the pattern fingerprint, followed by at most kMaxLoggedGroups groups.
// A trailing "..." line is emitted only when further groups were actually left
// out, so a pattern with exactly kMaxLoggedGroups groups is not shown as
// truncated.
void LogComputationDiffPattern(const ComputationDiffPattern& diff_pattern) {
  constexpr int kMaxLoggedGroups = 5;
  LOG(INFO) << diff_pattern.computation_groups.size()
            << " Repeated Diff Pattern Fingerprint: "
            << diff_pattern.fingerprint;
  int logged_groups = 0;
  for (const auto& computation_group : diff_pattern.computation_groups) {
    if (logged_groups == kMaxLoggedGroups) {
      // At least one more group exists beyond the logging limit.
      LOG(INFO) << "...";
      break;
    }
    LogComputationGroup(computation_group);
    ++logged_groups;
  }
}

}  // namespace

std::unique_ptr<const DiffSummary> ConstructDiffSummary(
    const HloGumgraph& left_graph, const HloGumgraph& right_graph,
    const DiffResult& diff_result) {
  auto summary = std::make_unique<DiffSummary>();

  // Summarize the computations.
  summary->computation_summary = GetComputationSummary(
      left_graph.GetHloModule(), right_graph.GetHloModule(), diff_result);

  // Group the computations by their diff fingerprint.
  summary->computation_diff_patterns =
      FindComputationDiffPatterns(summary->computation_summary, diff_result);

  // Summarize the instructions.
  summary->instruction_summary =
      GetInstructionSummary(left_graph, right_graph, diff_result);

  return summary;
}

// Convenience overload that starts from HLO modules instead of gumgraphs.
//
// A gumgraph is created for each module with HloGumgraph::Create(); if either
// creation fails, its error status is returned. Otherwise the work is
// delegated to the graph-based overload.
absl::StatusOr<std::unique_ptr<const DiffSummary>> ConstructDiffSummary(
    const HloModule& left_module, const HloModule& right_module,
    const DiffResult& diff_result) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<const HloGumgraph> left_graph,
                      HloGumgraph::Create(&left_module));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<const HloGumgraph> right_graph,
                      HloGumgraph::Create(&right_module));
  return ConstructDiffSummary(*left_graph, *right_graph, diff_result);
}

// Logs a human-readable overview of the diff summary at INFO level.
//
// Only diff patterns made of at least three computation groups are logged in
// detail; patterns with fewer groups are skipped.
void LogDiffSummary(const DiffSummary& diff_summary) {
  LOG(INFO) << "Diff Summary: ";

  const auto& patterns = diff_summary.computation_diff_patterns;
  if (patterns.empty()) {
    LOG(INFO) << "No diff patterns found.";
    return;
  }

  LOG(INFO) << "Found Repeated Diff Patterns: ";
  for (const ComputationDiffPattern& pattern : patterns) {
    // A pattern needs three or more groups to be reported as repeated.
    if (pattern.computation_groups.size() >= 3) {
      LogComputationDiffPattern(pattern);
    }
  }
}

void PrintTo(const ComputationDiffPattern& diff_pattern, std::ostream* os) {
  *os << "{ fingerprint: " << diff_pattern.fingerprint;
  for (const auto& computation_group : diff_pattern.computation_groups) {
    *os << ", computation_groups: "
        << "{ L: ";
    for (const HloComputation* computation :
         computation_group.left_computations) {
      *os << absl::StrFormat("%s ", computation->name());
    }
    *os << ", R: ";
    for (const HloComputation* computation :
         computation_group.right_computations) {
      *os << absl::StrFormat("%s ", computation->name());
    }
    *os << " }";
  }
  *os << ", diff_metrics: {"
      << diff_pattern.diff_metrics.changed_instruction_count << " changed, "
      << diff_pattern.diff_metrics.left_unmatched_instruction_count
      << " left unmatched, "
      << diff_pattern.diff_metrics.right_unmatched_instruction_count
      << " right unmatched }";
  *os << " }";
}

// Serializes the computation diff patterns of this summary.
//
// For every pattern, the proto records the fingerprint, the three diff-metric
// counters and all computation groups. Each computation is stored by name
// together with the names of its instructions. The computation and instruction
// summaries are not serialized here.
DiffSummaryProto DiffSummary::ToProto() const {
  // Copies the name of a computation and of all its instructions.
  const auto fill_details = [](const HloComputation* computation,
                               ComputationDetailsProto* details) {
    details->set_name(computation->name());
    for (const HloInstruction* instruction : computation->instructions()) {
      details->add_instructions(instruction->name());
    }
  };

  DiffSummaryProto proto;
  for (const auto& pattern : computation_diff_patterns) {
    ComputationDiffPatternProto* pattern_proto =
        proto.add_computation_diff_patterns();
    pattern_proto->set_fingerprint(pattern.fingerprint);

    const auto& metrics = pattern.diff_metrics;
    pattern_proto->set_changed_instruction_count(
        metrics.changed_instruction_count);
    pattern_proto->set_left_unmatched_instruction_count(
        metrics.left_unmatched_instruction_count);
    pattern_proto->set_right_unmatched_instruction_count(
        metrics.right_unmatched_instruction_count);

    for (const auto& group : pattern.computation_groups) {
      ComputationGroupProto* group_proto =
          pattern_proto->add_computation_group();
      for (const HloComputation* computation : group.left_computations) {
        fill_details(computation, group_proto->add_left_computations());
      }
      for (const HloComputation* computation : group.right_computations) {
        fill_details(computation, group_proto->add_right_computations());
      }
    }
  }
  return proto;
}

// Rebuilds the computation diff patterns of a DiffSummary from `proto`.
//
// Computations are referenced by name in the proto and are resolved against
// `left_module` / `right_module`. A name that does not exist in the
// corresponding module (for example because the proto was produced from a
// different module) is skipped with a warning rather than terminating the
// process. Only `computation_diff_patterns` is populated; the computation and
// instruction summaries of the returned object are left default-constructed.
DiffSummary DiffSummary::FromProto(const DiffSummaryProto& proto,
                                   const HloModule& left_module,
                                   const HloModule& right_module) {
  using ComputationsByName =
      absl::flat_hash_map<std::string, const HloComputation*>;

  // Indexes the computations of a module by name.
  const auto index_by_name = [](const HloModule& module) {
    ComputationsByName by_name;
    for (const HloComputation* computation : module.computations()) {
      by_name[computation->name()] = computation;
    }
    return by_name;
  };
  const ComputationsByName left_by_name = index_by_name(left_module);
  const ComputationsByName right_by_name = index_by_name(right_module);

  // Appends every computation named in `details_list` to `out`, skipping names
  // that are unknown to the module.
  const auto resolve_into = [](const ComputationsByName& by_name,
                               const auto& details_list,
                               const char* side_label, auto& out) {
    for (const auto& computation_details : details_list) {
      const auto it = by_name.find(computation_details.name());
      if (it == by_name.end()) {
        LOG(WARNING) << "Skipping " << side_label << " computation '"
                     << computation_details.name()
                     << "' from DiffSummaryProto: not found in the module.";
        continue;
      }
      out.push_back(it->second);
    }
  };

  DiffSummary summary;
  for (const auto& diff_pattern_proto : proto.computation_diff_patterns()) {
    ComputationDiffPattern diff_pattern;
    diff_pattern.fingerprint = diff_pattern_proto.fingerprint();
    diff_pattern.diff_metrics.changed_instruction_count =
        diff_pattern_proto.changed_instruction_count();
    diff_pattern.diff_metrics.left_unmatched_instruction_count =
        diff_pattern_proto.left_unmatched_instruction_count();
    diff_pattern.diff_metrics.right_unmatched_instruction_count =
        diff_pattern_proto.right_unmatched_instruction_count();
    for (const auto& group_proto : diff_pattern_proto.computation_group()) {
      ComputationGroup group;
      resolve_into(left_by_name, group_proto.left_computations(), "left",
                   group.left_computations);
      resolve_into(right_by_name, group_proto.right_computations(), "right",
                   group.right_computations);
      // Move rather than copy: the locals are not used afterwards.
      diff_pattern.computation_groups.push_back(std::move(group));
    }
    summary.computation_diff_patterns.push_back(std::move(diff_pattern));
  }
  return summary;
}

}  // namespace hlo_diff
}  // namespace xla
