/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/backends/gpu/codegen/emitters/scatter.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <memory>
#include <numeric>
#include <optional>
#include <utility>
#include <vector>

#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "xla/backends/gpu/codegen/fusion_emitter.h"
#include "xla/codegen/emitters/computation_partitioner.h"
#include "xla/codegen/emitters/elemental_hlo_to_mlir.h"
#include "xla/codegen/emitters/ir/xla_ops.h"
#include "xla/codegen/emitters/type_util.h"
#include "xla/codegen/emitters/utils.h"
#include "xla/hlo/analysis/indexing_analysis.h"
#include "xla/hlo/analysis/indexing_map.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/primitive_util.h"
#include "xla/service/gpu/gpu_fusible.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/scatter_simplifier.h"
#include "xla/shape.h"
#include "xla/stream_executor/device_description.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"

namespace xla {
namespace gpu {
namespace {

namespace arith = ::mlir::arith;
namespace scf = ::mlir::scf;
namespace vector = ::mlir::vector;
namespace tensor = ::mlir::tensor;

using emitters::CallTargetProvider;
using emitters::EmitXlaLoopOp;
using emitters::PartitionedComputations;
using emitters::ProvideParameter;
using llvm::APFloat;
using llvm::APInt;
using llvm::ArrayRef;
using llvm::SmallVector;
using mlir::AffineExpr;
using mlir::AffineMap;
using mlir::DenseElementsAttr;
using mlir::getAffineDimExpr;
using mlir::getAffineSymbolExpr;
using mlir::ImplicitLocOpBuilder;
using mlir::Location;
using mlir::MLIRContext;
using mlir::OpBuilder;
using mlir::Value;
using mlir::ValueRange;
using mlir::VectorType;
using mlir::func::FuncOp;
using mlir::func::ReturnOp;
using primitive_util::IsUnsignedIntegralType;

// Number of warps per block assigned to `num_warps_` by the
// ScatterWithDistributedIndices constructor.
constexpr int64_t kNumWarpsPerBlock = 4;
// NOTE(review): not referenced in the part of the file reviewed here;
// presumably the widest vectorized access in bits -- confirm at the use site.
constexpr int64_t kMaxVectorizedBits = 128;
// Positions passed as `operand_index` to EmitterHelper::GetElement to select
// the scatter operand, the scatter indices and the scatter update
// respectively.
constexpr int64_t kScatterOperandIndex = 0;
constexpr int64_t kScatterIndicesIndex = 1;
constexpr int64_t kScatterUpdateIndex = 2;

// Builds an `scf.if` guarded by `condition`:
//
//   if (condition) {
//     yield updated_values_fn();
//   } else {
//     yield values;
//   }
//
// `updated_values_fn` is called with a builder positioned inside the `then`
// region. Returns the results of the created `scf.if`.
ValueRange EmitUpdateIf(
    ImplicitLocOpBuilder& b, Value condition, ValueRange values,
    llvm::function_ref<SmallVector<Value>(ImplicitLocOpBuilder&)>
        updated_values_fn) {
  auto build_then = [&](OpBuilder& then_builder, Location loc) {
    ImplicitLocOpBuilder nested_b(loc, then_builder);
    SmallVector<Value> updated_values = updated_values_fn(nested_b);
    scf::YieldOp::create(then_builder, loc, updated_values);
  };
  auto build_else = [&](OpBuilder& else_builder, Location loc) {
    // Nothing to update: forward the incoming values unchanged.
    scf::YieldOp::create(else_builder, loc, values);
  };
  auto if_op = scf::IfOp::create(b, condition, build_then, build_else);
  return if_op.getResults();
}

// Emits an i1 value telling whether a slice of shape `slice_shape` placed at
// `offsets` fits into an operand of shape `operand_shape`. Dimensions are
// paired positionally; the iteration stops at the shortest of the three
// sequences.
Value EmitBoundsCheck(ImplicitLocOpBuilder& b,
                      absl::Span<const int64_t> slice_shape,
                      absl::Span<const int64_t> operand_shape,
                      ValueRange offsets) {
  Value all_in_bounds = arith::ConstantIntOp::create(b, b.getI1Type(), 1);
  for (auto [slice_dim, operand_dim, offset] :
       llvm::zip(slice_shape, operand_shape, offsets)) {
    // Largest offset at which the slice still fits along this dimension.
    Value max_offset =
        arith::ConstantIndexOp::create(b, operand_dim - slice_dim);
    // A single unsigned comparison also covers signed indices: because
    // `max_offset >= 0`, `offset ule max_offset` implies `offset sge 0`.
    Value dim_in_bounds = b.createOrFold<arith::CmpIOp>(
        arith::CmpIPredicate::ule, offset, max_offset);
    all_in_bounds =
        b.createOrFold<arith::AndIOp>(all_in_bounds, dim_in_bounds);
  }
  return all_in_bounds;
}

// Emits an i1 value that is true iff at least one pair of positionally matched
// elements of `lhs` and `rhs` differs. Starts from `false` and ORs in one `ne`
// comparison per pair.
Value EmitInequalityCheck(ImplicitLocOpBuilder& b, ValueRange lhs,
                          ValueRange rhs) {
  Value any_differs = arith::ConstantIntOp::create(b, b.getI1Type(), 0);
  for (auto [left, right] : llvm::zip(lhs, rhs)) {
    Value pair_differs = b.createOrFold<arith::CmpIOp>(
        arith::CmpIPredicate::ne, left, right);
    any_differs = b.createOrFold<arith::OrIOp>(any_differs, pair_differs);
  }
  return any_differs;
}

// Emits a conditional refresh of the in-bounds flag: when `offsets_changed` is
// true the bounds check is recomputed for `offsets`, otherwise the incoming
// `is_inbounds` is forwarded unchanged.
Value UpdateIsInbounds(ImplicitLocOpBuilder& b, Value is_inbounds,
                       Value offsets_changed, ValueRange offsets,
                       absl::Span<const int64_t> slice_shape,
                       absl::Span<const int64_t> operand_shape) {
  auto recompute_flag =
      [&](ImplicitLocOpBuilder& then_b) -> SmallVector<Value> {
    Value refreshed =
        EmitBoundsCheck(then_b, slice_shape, operand_shape, offsets);
    return {refreshed};
  };
  ValueRange if_results =
      EmitUpdateIf(b, offsets_changed, is_inbounds, recompute_flag);
  return if_results.front();
}

// Concatenates all `ranges`, in order, into a single flat vector of values.
SmallVector<Value> Pack(ArrayRef<ValueRange> ranges) {
  // Size the result up front so the appends below do not reallocate.
  const int64_t num_values = std::accumulate(
      ranges.begin(), ranges.end(), int64_t{0},
      [](int64_t count, const ValueRange& piece) {
        return count + static_cast<int64_t>(piece.size());
      });
  SmallVector<Value> packed;
  packed.reserve(num_values);
  for (const ValueRange& piece : ranges) {
    llvm::append_range(packed, piece);
  }
  return packed;
}

// Inverse of `Pack`: splits `range` into consecutive sub-ranges whose lengths
// are given by `sizes`. The sizes are asserted to add up to `range.size()`.
SmallVector<ValueRange> Unpack(ValueRange range, ArrayRef<int64_t> sizes) {
  int64_t expected_size = 0;
  for (const int64_t chunk_size : sizes) {
    expected_size += chunk_size;
  }
  assert(expected_size == range.size());

  SmallVector<ValueRange> pieces;
  pieces.reserve(sizes.size());
  ValueRange remaining = range;
  for (const int64_t chunk_size : sizes) {
    // Peel the next chunk off the front and continue with the tail.
    pieces.push_back(remaining.take_front(chunk_size));
    remaining = remaining.drop_front(chunk_size);
  }
  return pieces;
}

// Returns a copy of `values` extended with constant-zero index values until it
// holds `size` elements. If `values` already has `size` or more elements
// (which includes any non-positive `size`), the copy is returned unchanged and
// no constant is emitted.
SmallVector<Value, 4> PadWithZeros(ValueRange values, int64_t size,
                                   ImplicitLocOpBuilder& b) {
  SmallVector<Value, 4> padded_values(values.begin(), values.end());
  // Compare in the signed domain: `size()` is unsigned, and a mixed comparison
  // would silently turn a negative `size` into a huge unsigned value.
  const int64_t num_values = static_cast<int64_t>(padded_values.size());
  if (num_values >= size) {
    return padded_values;
  }
  // One zero constant is shared by every padding position.
  Value zero = arith::ConstantIndexOp::create(b, 0);
  padded_values.append(static_cast<size_t>(size - num_values), zero);
  return padded_values;
}

}  // namespace

// Groups the state needed while emitting a scatter kernel: the scatter
// description, the entry function, the call-target provider and the
// partitioned computation of the fusion. Offers helpers to read scatter
// inputs and to apply the scatter's combiner.
class EmitterHelper {
 public:
  // Stores non-owning pointers to `description` and `call_targets`, and looks
  // up the partitioned computation of `fusion` in `computations`.
  EmitterHelper(const ScatterDescription& description,
                const PartitionedComputations* computations,
                const CallTargetProvider* call_targets, FuncOp entry_function,
                const HloFusionInstruction& fusion)
      : description_(&description),
        entry_function_(entry_function),
        call_targets_(call_targets),
        root_computation_(&computations->FindPartitionedComputation(
            fusion.fused_instructions_computation())) {}

  // Emits a read of the scatter operand at `indices`.
  Value GetOperandElement(ImplicitLocOpBuilder& b, ValueRange indices) const {
    return GetElement(b, kScatterOperandIndex, indices);
  }

  // Emits a read of the scatter indices input at `indices`.
  Value GetIndicesElement(ImplicitLocOpBuilder& b, ValueRange indices) const {
    return GetElement(b, kScatterIndicesIndex, indices);
  }

  // Emits a read of the scatter update at `indices`.
  Value GetUpdateElement(ImplicitLocOpBuilder& b, ValueRange indices) const {
    return GetElement(b, kScatterUpdateIndex, indices);
  }

  // Returns the function that the call-target provider associates with the
  // root instruction of the scatter's first called computation.
  FuncOp GetReducer() const {
    const auto* combiner = description_->scatter->called_computations()[0];
    return (*call_targets_)(combiner->root_instruction());
  }

  // Reads the index vector of slice `slice_id` and converts every component
  // to the `index` type.
  SmallVector<Value, 4> ExtractOffsets(ImplicitLocOpBuilder& b,
                                       Value slice_id) const;

  // Combines `update_elem` into `output_tensor` at `indices` using the
  // reducer and returns the updated tensor.
  Value EmitScatterComputation(ImplicitLocOpBuilder& b, ValueRange indices,
                               Value update_elem, Value output_tensor) const;

  // Extracts one accumulator element and scatters it into `output_tensor` at
  // the position derived from `slice_indices` and `offsets`.
  SmallVector<Value> WriteAccumulatedElementToOutput(
      ImplicitLocOpBuilder& b, Value accumulator,
      ValueRange accumulator_indices, ValueRange slice_indices,
      ValueRange offsets, Value output_tensor) const;

  // Writes the whole accumulator to `output_tensor` when
  // `write_to_output_required` is true; otherwise forwards `output_tensor`.
  Value WriteAccumulatorToOutput(ImplicitLocOpBuilder& b,
                                 Value write_to_output_required,
                                 ValueRange thread_and_block_ids,
                                 Value index_id,
                                 const IndexingMap& slice_indexing,
                                 ValueRange offsets, Value accumulator,
                                 Value output_tensor) const;

 private:
  // Shared implementation of the Get*Element accessors above.
  Value GetElement(ImplicitLocOpBuilder& b, int operand_index,
                   ValueRange indices) const;

  // All pointers below are non-owning.
  const ScatterDescription* description_;
  FuncOp entry_function_;
  const emitters::CallTargetProvider* call_targets_;
  const emitters::PartitionedComputation* root_computation_;
};

// Loads the `index_vector_length` components of the index vector belonging to
// `slice_id` and casts each of them to the `index` type. Unsigned index
// element types use a zero-extending cast, all others a sign-extending one.
SmallVector<Value, 4> EmitterHelper::ExtractOffsets(ImplicitLocOpBuilder& b,
                                                    Value slice_id) const {
  const auto num_components = description_->index_vector_length;
  const bool indices_are_unsigned = IsUnsignedIntegralType(
      description_->scatter->scatter_indices()->shape().element_type());
  auto index_type = b.getIndexType();

  SmallVector<Value, 4> offsets;
  offsets.reserve(num_components);
  for (int component = 0; component < num_components; ++component) {
    // The indices input is addressed as [slice_id, component].
    SmallVector<Value, 4> element_indices = {
        slice_id, arith::ConstantIndexOp::create(b, component)};
    Value raw_index = GetIndicesElement(b, element_indices);
    Value offset;
    if (indices_are_unsigned) {
      offset =
          arith::IndexCastUIOp::create(b, index_type, raw_index).getResult();
    } else {
      offset =
          arith::IndexCastOp::create(b, index_type, raw_index).getResult();
    }
    offsets.push_back(offset);
  }
  return offsets;
}

// Applies the scatter's reducer to the current output element at `indices`
// and `update_elem`, and returns the updated output tensor.
//
// When the scatter does not declare unique indices, the reducer body is
// inlined into the region of an AtomicRMWOp. Otherwise the operand element is
// read, combined, and inserted with a plain tensor insert.
Value EmitterHelper::EmitScatterComputation(ImplicitLocOpBuilder& b,
                                            ValueRange indices,
                                            Value update_elem,
                                            Value output_tensor) const {
  FuncOp reducer = GetReducer();
  auto& reducer_block = reducer.getBody().front();

  if (!description_->scatter->unique_indices()) {
    auto atomic_rmw = AtomicRMWOp::create(b, output_tensor, indices);
    OpBuilder region_b = atomic_rmw.getBodyBuilder();
    auto combined =
        emitters::InlineBlock(region_b, reducer_block,
                              {atomic_rmw.getCurrentValue(), update_elem})[0];
    xla::YieldOp::create(region_b, reducer->getLoc(), combined);
    return atomic_rmw->getResult(0);
  }

  auto current_elem = GetOperandElement(b, indices);
  auto combined = emitters::InlineBlock(b, reducer_block,
                                        {current_elem, update_elem})[0];
  return tensor::InsertOp::create(b, combined, output_tensor, indices);
}

// Extracts the accumulator element at `accumulator_indices` and scatters it
// into `output_tensor`. The output position along dimension `i` is
// `slice_indices[i + 1] + offsets[i]`; element 0 of `slice_indices` is
// skipped. Returns the updated output tensor as a single-element vector.
SmallVector<Value> EmitterHelper::WriteAccumulatedElementToOutput(
    ImplicitLocOpBuilder& b, Value accumulator, ValueRange accumulator_indices,
    ValueRange slice_indices, ValueRange offsets, Value output_tensor) const {
  Value element = vector::ExtractOp::create(
      b, accumulator, mlir::getAsOpFoldResult(accumulator_indices));

  SmallVector<Value, 4> output_position;
  output_position.reserve(offsets.size());
  for (size_t dim = 0; dim < offsets.size(); ++dim) {
    Value shifted =
        arith::AddIOp::create(b, slice_indices[dim + 1], offsets[dim]);
    output_position.push_back(shifted);
  }

  Value updated_output =
      EmitScatterComputation(b, output_position, element, output_tensor);
  return {updated_output};
}

// Emits
//   if (write_to_output_required) {
//     loop over `slice_indexing`, scattering each accumulator element
//   } else {
//     keep `output_tensor`
//   }
// and returns the resulting output tensor. The loop dimensions are the thread
// and block ids followed by `index_id`.
Value EmitterHelper::WriteAccumulatorToOutput(
    ImplicitLocOpBuilder& b, Value write_to_output_required,
    ValueRange thread_and_block_ids, Value index_id,
    const IndexingMap& slice_indexing, ValueRange offsets, Value accumulator,
    Value output_tensor) const {
  SmallVector<Value> loop_dims = Pack({thread_and_block_ids, index_id});

  // Loop body: scatter one accumulator element into the running output.
  auto write_element =
      [&](ImplicitLocOpBuilder& loop_b, ValueRange accumulator_indices,
          ValueRange slice_indices,
          ValueRange output_tensors) -> SmallVector<Value> {
    return WriteAccumulatedElementToOutput(loop_b, accumulator,
                                           accumulator_indices, slice_indices,
                                           offsets, output_tensors.front());
  };
  // `then` branch: run the loop with the output tensor as its iteration
  // argument.
  auto flush_accumulator =
      [&](ImplicitLocOpBuilder& if_b) -> SmallVector<Value> {
    return EmitXlaLoopOp(if_b, loop_dims, output_tensor, slice_indexing,
                         write_element);
  };

  return EmitUpdateIf(b, write_to_output_required, output_tensor,
                      flush_accumulator)
      .front();
}

// Emits a read of operand `operand_index` of the scatter instruction at
// `indices` via `ProvideParameter` and returns the first provided value.
Value EmitterHelper::GetElement(ImplicitLocOpBuilder& b, int operand_index,
                                ValueRange indices) const {
  auto provided =
      ProvideParameter(*root_computation_, description_->scatter,
                       operand_index, indices, *call_targets_,
                       entry_function_, b);
  return provided[0];
}

// Stores the analysis, scatter description, MLIR context and vector size, and
// caches the warp size reported for the analysis' device.
ScatterFusion::ScatterFusion(const HloFusionAnalysis& analysis,
                             const ScatterDescription& description,
                             int64_t vector_size, MLIRContext* mlir_context)
    : analysis_(analysis),
      description_(description),
      mlir_context_(mlir_context),
      // Read the device info through the constructor argument so this
      // initializer does not depend on `analysis_` being set up first.
      warp_size_(WarpSize(analysis.device_info())),
      vector_size_(vector_size) {}

std::optional<std::vector<IndexingMap>>
ScatterFusion::ComputeThreadIdToInputIndexing(int64_t root_index,
                                              MLIRContext* ctx) const {
  CHECK(ScatterSimplifier::IsSimplifiedScatter(description_.scatter))
      << "Non-simplified HLO Scatter is not supported.";

  int64_t scatter_operand_count = description_.scatter->scatter_operand_count();
  // Scatter operands are packed in the following way:
  // Operand IDs [0, scatter_operand_count - 1] for `scatter operands`.
  // Operand ID  scatter_operand_count for `scatter indices`.
  // Operand IDs [scatter_operand_count + 1, 2 * scatter_operand_count] for
  // `scatter updates`.

  std::vector<IndexingMap> results(description_.scatter->operand_count(),
                                   IndexingMap::GetUndefined());
  // Compute the indexing for the scatter indices operand.
  ComputeIndexing(ctx, /*updates_map=*/nullptr,
                  &results[scatter_operand_count]);
  // For scatter operands we do not know the thread ID indexing.
  for (int64_t operand_index = scatter_operand_count + 1;
       operand_index < results.size(); ++operand_index) {
    ComputeIndexing(ctx, &results[operand_index], /*indices_map=*/nullptr);
  }
  return results;
}

// Returns a single identity-indexing epilogue spanning from the first fusion
// hero to the first fusion root. `fusion` is not used.
std::vector<emitters::EpilogueSpecification> ScatterFusion::GetEpilogues(
    const HloFusionInstruction& fusion, MLIRContext* mlir_context) const {
  // Epilogues are not actually supported for scatter; returning this one is
  // how the base class is told not to generate code for the scatter itself.
  const auto& hero = analysis_.fusion_hero(0).instruction();
  const auto& root = analysis_.fusion_root(0).instruction();
  return {emitters::EpilogueSpecification::FromIdentityIndexing(
      &hero, &root, mlir_context)};
}

// Derives the grid from the launch dimensions computed for the update shape,
// using `vector_size_` as the unroll factor. `num_warps_` is the number of
// threads per block divided by the warp size, rounded up.
ScatterWithDistributedUpdates::ScatterWithDistributedUpdates(
    const HloFusionAnalysis& analysis, const ScatterDescription& description,
    int64_t vector_size, MLIRContext* mlir_context)
    : ScatterFusion(analysis, description, vector_size, mlir_context) {
  // We have to make sure that no thread processes elements of two different
  // update slices.
  auto launch_dims = CalculateLaunchDimensions(
      description_.update_shape, analysis_.device_info(),
      {static_cast<int>(vector_size_)});
  const int64_t threads_per_block =
      static_cast<int64_t>(launch_dims.num_threads_per_block());
  num_blocks_ = launch_dims.num_blocks();
  num_warps_ = CeilOfRatio(threads_per_block, warp_size_);
}

// Fills the requested indexing maps; either output pointer may be null.
//  - `updates_map` receives the default thread-id indexing map for the update
//    shape.
//  - `indices_map` receives a map whose first result is the first result of
//    the updates map and whose second result is a symbol ranging over the
//    components of the index vector.
void ScatterWithDistributedUpdates::ComputeIndexing(
    MLIRContext* mlir_context, IndexingMap* updates_map,
    IndexingMap* indices_map) const {
  // The thread id mapping is based on the first update operand.
  IndexingMap update_indexing =
      GetDefaultThreadIdIndexingMap(launch_dimensions(), vector_size_,
                                    description_.update_shape, mlir_context);

  if (indices_map != nullptr) {
    // For the scatter indices, project the update indexing onto its first
    // result only, because the two coincide.
    AffineExpr update_id = update_indexing.GetAffineMap().getResult(0);
    AffineExpr index_component = getAffineSymbolExpr(0, mlir_context);
    *indices_map = IndexingMap{
        AffineMap::get(6, 1, {update_id, index_component}, mlir_context),
        DimVarsFromGPUGrid({num_warps_ * warp_size_, 1, 1, num_blocks_, 1, 1}),
        RangeVarsFromTensorSizes({description_.index_vector_length}),
        /*rt_vars=*/{}};
    indices_map->Simplify();
  }

  // Hand over the update map last: it is still read above.
  if (updates_map != nullptr) {
    *updates_map = std::move(update_indexing);
  }
}

// Sets up the entry block of `entry_function`, emits the thread and block
// ids, computes the update and indices indexing maps, and delegates the
// emission of the body to `EmitEntryFunctionImpl`. The output tensor is the
// last argument of the entry function.
absl::Status ScatterFusion::EmitEntryFunction(
    const PartitionedComputations& computations,
    const CallTargetProvider& call_targets, FuncOp entry_function,
    const HloFusionInstruction& fusion) const {
  // Position the builder at the start of a fresh entry block.
  ImplicitLocOpBuilder b(entry_function.getLoc(), entry_function);
  b.setInsertionPointToStart(entry_function.addEntryBlock());

  EmitterHelper helper(description_, &computations, &call_targets,
                       entry_function, fusion);

  auto thread_and_block_ids = EmitThreadAndBlockIds(b);
  Value output_tensor = entry_function.getArguments().back();

  IndexingMap updates_indexing = IndexingMap::GetUndefined();
  IndexingMap indices_indexing = IndexingMap::GetUndefined();
  ComputeIndexing(mlir_context_, &updates_indexing, &indices_indexing);
  updates_indexing.Simplify();

  return EmitEntryFunctionImpl(b, helper, updates_indexing, indices_indexing,
                               thread_and_block_ids, output_tensor);
}

// Emits the non-accumulating scatter body and the function's return:
//
//   if (index_id < num_slices) {
//     offsets = index vector of slice `index_id`
//     if (slice at `offsets` fits into the output) {
//       for each update element mapped to this thread:
//         output[offsets + position in slice] = combine(..., update element)
//     }
//   }
//   return output
//
// `index_id` is the first result of `updates_map` applied to the thread and
// block ids. `indices_map` is not used.
void EmitNaiveImplementation(ImplicitLocOpBuilder& b,
                             const ScatterDescription& description,
                             const EmitterHelper& helper,
                             const IndexingMap& updates_map,
                             const IndexingMap& indices_map,
                             ValueRange thread_and_block_ids,
                             Value output_tensor) {
  MLIRContext* mlir_context = b.getContext();
  // Keep only the first result of the updates map and drop its range
  // variables.
  auto thread_id_to_update_id_map = IndexingMap(
      AffineMap::get(6, 0, {updates_map.GetAffineMap().getResult(0)},
                     mlir_context),
      updates_map.GetDimVars(),
      /*range_vars = */ {}, /*rt vars = */ {});
  Value thread_id_to_index_id_value =
      emitters::ApplyIndexing(thread_id_to_update_id_map, thread_and_block_ids,
                              {}, b)
          .front();
  // Guard the read of the index vector below.
  Value index_id_in_bounds = b.createOrFold<arith::CmpIOp>(
      arith::CmpIPredicate::ult, thread_id_to_index_id_value,
      arith::ConstantIndexOp::create(b, description.num_slices));
  const int64_t output_rank = description.output_shape.size();

  auto result = EmitUpdateIf(
      b, index_id_in_bounds, {output_tensor},
      [&](ImplicitLocOpBuilder& outer_nested_b) -> SmallVector<Value> {
        SmallVector<Value, 4> update_offsets =
            helper.ExtractOffsets(outer_nested_b, thread_id_to_index_id_value);

        Value in_bounds =
            EmitBoundsCheck(outer_nested_b, description.slice_shape,
                            description.output_shape, update_offsets);

        ValueRange predicated_update = EmitUpdateIf(
            outer_nested_b, in_bounds, {output_tensor},
            [&](ImplicitLocOpBuilder& nested_b) -> SmallVector<Value> {
              return EmitXlaLoopOp(
                  nested_b, thread_and_block_ids, {output_tensor}, updates_map,
                  [&](ImplicitLocOpBuilder& update_loop_b,
                      ValueRange symbol_values, ValueRange map_results,
                      ValueRange output_tensors) -> SmallVector<Value> {
                    // Extract update element.
                    auto update_elem =
                        helper.GetUpdateElement(update_loop_b, map_results);
                    // `update_offsets` is only read here: it belongs to the
                    // enclosing callback and is captured by reference, so it
                    // must not be moved from.
                    SmallVector<Value, 4> output_indices = PadWithZeros(
                        update_offsets, output_rank, update_loop_b);
                    // `map_results[0]` is the update id; the remaining
                    // results are the position inside the slice.
                    for (size_t i = 0; i < output_indices.size(); ++i) {
                      output_indices[i] = arith::AddIOp::create(
                          update_loop_b, map_results[i + 1], output_indices[i]);
                    }
                    Value updated_output = helper.EmitScatterComputation(
                        update_loop_b, output_indices, update_elem,
                        output_tensors.front());
                    return {updated_output};
                  });
            });
        return predicated_update;
      });
  ReturnOp::create(b, result.front());
}

// Emits the naive (non-accumulating) scatter body. At VLOG level 5 the
// emitter settings are printed to stderr first. Always returns OK.
absl::Status ScatterWithDistributedUpdates::EmitEntryFunctionImpl(
    ImplicitLocOpBuilder& b, const EmitterHelper& helper,
    const IndexingMap& updates_map, const IndexingMap& indices_map,
    ValueRange thread_and_block_ids, Value output_tensor) const {
  if (VLOG_IS_ON(5)) {
    llvm::raw_ostream& os = llvm::errs();
    os << "Settings for ScatterWithDistributedUpdates: \n";
    os << "vector_size: " << vector_size_ << "\n";
    os << "num_warps: " << num_warps_ << "\n";
    os << "num_blocks: " << num_blocks_ << "\n";
  }
  EmitNaiveImplementation(b, description_, helper, updates_map, indices_map,
                          thread_and_block_ids, output_tensor);
  return absl::OkStatus();
}

// Stores the tiling parameters and sizes the grid: every block has
// `kNumWarpsPerBlock` warps, and the number of blocks is
// ceil(num_slices * num_warps_per_slice /
//      (num_indices_per_warp * num_warps)).
ScatterWithDistributedIndices::ScatterWithDistributedIndices(
    const HloFusionAnalysis& analysis, const ScatterDescription& description,
    int64_t vector_size, int64_t num_warps_per_slice,
    int64_t num_indices_per_warp, int64_t indices_vector_size,
    MLIRContext* mlir_context)
    : ScatterFusion(analysis, description, vector_size, mlir_context),
      num_warps_per_slice_(num_warps_per_slice),
      num_indices_per_warp_(num_indices_per_warp),
      indices_vector_size_(indices_vector_size) {
  num_warps_ = kNumWarpsPerBlock;
  const auto warp_work_items = description.num_slices * num_warps_per_slice_;
  const auto work_items_per_block = num_indices_per_warp_ * num_warps_;
  num_blocks_ = CeilOfRatio(warp_work_items, work_items_per_block);
}

// Fills the requested indexing maps; either output pointer may be null.
//
// Both maps are functions of the thread and block ids (6 dimensions) and of 3
// symbols:
//  - `indices_map`: symbols (index_id_loop, index_vector_id, index_dim),
//    results (index id, index component). The index id is constrained to
//    [0, num_slices - 1].
//  - `updates_map`: symbols (index_id_loop, update_loop, vector_id), results
//    (index id, delinearized position inside the slice). Both the index id
//    and the linear slice position are constrained to their valid ranges.
void ScatterWithDistributedIndices::ComputeIndexing(
    MLIRContext* mlir_context, IndexingMap* updates_map,
    IndexingMap* indices_map) const {
  auto symbol = [&](unsigned position) {
    return getAffineSymbolExpr(position, mlir_context);
  };

  // Compute thread id mapping based on the first update operand.
  AffineExpr thread_x = getAffineDimExpr(
      KernelFusionInterface::kIndexingMapThreadIdxDims[0], mlir_context);
  AffineExpr block_x = getAffineDimExpr(
      KernelFusionInterface::kIndexingMapBlockIdxDims[0], mlir_context);
  AffineExpr warp_id = thread_x.floorDiv(warp_size_);
  AffineExpr lane_id = thread_x % warp_size_;
  // Warp id across the whole grid.
  AffineExpr global_warp_id = block_x * num_warps_ + warp_id;
  AffineExpr slice_id = global_warp_id.floorDiv(num_warps_per_slice_);
  AffineExpr warp_id_in_slice = global_warp_id % num_warps_per_slice_;

  auto grid_vars =
      DimVarsFromGPUGrid({num_warps_ * warp_size_, 1, 1, num_blocks_, 1, 1});

  if (indices_map != nullptr) {
    AffineExpr index_id_loop = symbol(0);
    AffineExpr index_vector_id = symbol(1);
    AffineExpr index_dim_loop = symbol(2);
    AffineExpr vectorized_index_id_expr =
        slice_id * num_indices_per_warp_ +
        index_id_loop * indices_vector_size_ + index_vector_id;

    *indices_map = IndexingMap{
        AffineMap::get(6, 3, {vectorized_index_id_expr, index_dim_loop},
                       mlir_context),
        grid_vars,
        {IndexingMap::Variable{
             {0, num_indices_per_warp_ / indices_vector_size_ - 1},
             "index_id_loop"},
         IndexingMap::Variable{{0, indices_vector_size_ - 1},
                               "index_vector_id"},
         IndexingMap::Variable{{0, description_.index_vector_length - 1},
                               "index_dim"}},
        /*rt_vars=*/{},
        std::vector<std::pair<AffineExpr, Interval>>{
            std::make_pair(vectorized_index_id_expr,
                           Interval{0, description_.num_slices - 1})}};
    indices_map->Simplify();
  }

  if (updates_map != nullptr) {
    AffineExpr index_id = symbol(0);
    AffineExpr update_dim_loop = symbol(1);
    AffineExpr vector_id = symbol(2);
    auto num_elements_per_slice = Product(description_.slice_shape);

    AffineExpr index_id_expr = slice_id * num_indices_per_warp_ + index_id;
    // Linear position inside the slice, built from the warp's position in
    // the slice, the loop iteration, the lane and the vector element.
    AffineExpr linear_slice_index =
        warp_id_in_slice * warp_size_ * vector_size_ +
        update_dim_loop * vector_size_ * warp_size_ * num_warps_per_slice_ +
        lane_id * vector_size_ + vector_id;

    SmallVector<AffineExpr, 4> update_results = {index_id_expr};
    update_results.append(
        DelinearizeInBoundsIndex(linear_slice_index, description_.slice_shape));

    *updates_map = IndexingMap{
        AffineMap::get(6, 3, update_results, mlir_context),
        grid_vars,
        {IndexingMap::Variable{{0, num_indices_per_warp_ - 1}, "index_id_loop"},
         IndexingMap::Variable{
             {0, CeilOfRatio(num_elements_per_slice,
                             num_warps_per_slice_ * warp_size_ * vector_size_) -
                     1},
             "update_loop"},
         IndexingMap::Variable{{0, vector_size_ - 1}, "vector_id"}},
        /*rt_vars=*/{},
        std::vector<std::pair<AffineExpr, Interval>>{
            std::make_pair(index_id_expr,
                           Interval{0, description_.num_slices - 1}),
            std::make_pair(linear_slice_index,
                           Interval{0, num_elements_per_slice - 1})}};
    updates_map->Simplify();
  }
}

// Emits a zero-filled constant vector used as the initial accumulator. Its
// shape is [update iterations per thread, vector_size_], where the iteration
// count is the number of elements per slice divided by
// num_warps_per_slice_ * warp_size_ * vector_size_, rounded up.
Value ScatterWithDistributedIndices::InitializeAccumulator(
    ImplicitLocOpBuilder& b) const {
  auto element_type =
      emitters::PrimitiveTypeToMlirType(description_.elem_type, b);
  const auto elements_per_slice = Product(description_.slice_shape);
  const auto elements_per_iteration =
      num_warps_per_slice_ * warp_size_ * vector_size_;
  const auto iterations_per_thread =
      CeilOfRatio(elements_per_slice, elements_per_iteration);

  auto accumulator_type = VectorType::get(
      {iterations_per_thread, vector_size_}, element_type);
  auto zeros = emitters::GetZeroDenseElementsAttr(accumulator_type);
  return arith::ConstantOp::create(b, accumulator_type, zeros);
}

absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl(
    ImplicitLocOpBuilder& b, const EmitterHelper& helper,
    const IndexingMap& updates_map, const IndexingMap& indices_map,
    ValueRange thread_and_block_ids, Value output_tensor) const {
  if (VLOG_IS_ON(5)) {
    llvm::errs() << "Settings for ScatterWithDistributedIndices: \n"
                 << "vector_size: " << vector_size_ << "\n"
                 << "num_warps: " << num_warps_ << "\n"
                 << "num_blocks: " << num_blocks_ << "\n"
                 << "num_warps_per_slice: " << num_warps_per_slice_ << "\n"
                 << "num_indices_per_warp: " << num_indices_per_warp_ << "\n"
                 << "indices_vector_size: " << indices_vector_size_ << "\n";
  }
  if (num_indices_per_warp_ == 1) {
    EmitNaiveImplementation(b, description_, helper, updates_map, indices_map,
                            thread_and_block_ids, output_tensor);
    return absl::OkStatus();
  }
  MLIRContext* mlir_context = b.getContext();

  auto thread_id_to_update_id_map = IndexingMap(
      AffineMap::get(6, 2, {indices_map.GetAffineMap().getResult(0)},
                     mlir_context),
      indices_map.GetDimVars(),
      /*range_vars = */
      {indices_map.GetRangeVars().begin(),
       indices_map.GetRangeVars().begin() + 2},
      /*rt vars = */ {}, indices_map.GetConstraints());

  // Convert index_id_loop and index_vector_id to dimension variables.
  IndexingMap slice_indexing =
      ConvertRangeVariablesToDimensions(updates_map, {0});

  // Prepare loop initial values. Inits are packed as
  // [index_changed, is_inbounds, index_0,  ..., accumulator].
  Value is_inbounds_init = arith::ConstantIntOp::create(b, b.getI1Type(), 0);
  Value slice_id_init = arith::ConstantIndexOp::create(b, 0);
  std::vector<Value> indices_init(description_.index_vector_length,
                                  arith::ConstantIndexOp::create(b, -1));
  Value accumulator_init = InitializeAccumulator(b);
  SmallVector<Value> inits =
      Pack({slice_id_init, indices_init, is_inbounds_init, accumulator_init,
            output_tensor});

  int64_t output_rank = description_.output_shape.size();

  auto loop_over_indices_fn =
      [&](ImplicitLocOpBuilder& nested_b, ValueRange ivs,
          ValueRange thread_id_to_index_id_value,
          ValueRange outer_iter_args) -> SmallVector<Value> {
    // Unpack the iter_args.
    SmallVector<ValueRange> iter_args_unpack =
        Unpack(outer_iter_args, {1, description_.index_vector_length, 1, 1, 1});
    ValueRange trimmed_offsets = iter_args_unpack[1];
    Value iter_is_inbounds = iter_args_unpack[2].front();
    Value iter_acc = iter_args_unpack[3].front();
    Value iter_output = iter_args_unpack[4].front();
    CHECK_EQ(ivs.size(), 2);
    Value index_loop_id = ivs.front();
    Value index_vector_id = ivs.back();
    Value iter_slice_id = arith::AddIOp::create(
        nested_b,
        arith::MulIOp::create(
            nested_b, index_loop_id,
            arith::ConstantIndexOp::create(nested_b, indices_vector_size_)),
        index_vector_id);

    SmallVector<Value> offsets =
        PadWithZeros(trimmed_offsets, output_rank, nested_b);

    auto new_trimmed_offsets =
        helper.ExtractOffsets(nested_b, thread_id_to_index_id_value.front());

    // Check if the offsets changed.
    Value offsets_changed =
        EmitInequalityCheck(nested_b, trimmed_offsets, new_trimmed_offsets);

    for (int i = 0; i < description_.index_vector_length; ++i) {
      new_trimmed_offsets[i] =
          arith::SelectOp::create(nested_b, offsets_changed,
                                  new_trimmed_offsets[i], trimmed_offsets[i]);
    }

    auto new_offsets = PadWithZeros(new_trimmed_offsets, output_rank, nested_b);

    // Write accumulated values into the tensor if the offsets changed.
    Value is_not_first_iteration =
        arith::CmpIOp::create(b, arith::CmpIPredicate::ne, iter_slice_id,
                              arith::ConstantIndexOp::create(b, 0));
    Value write_to_output_required = arith::AndIOp::create(
        b, is_not_first_iteration,
        arith::AndIOp::create(b, offsets_changed, iter_is_inbounds));
    iter_output = helper.WriteAccumulatorToOutput(
        b, write_to_output_required, thread_and_block_ids, iter_slice_id,
        slice_indexing, offsets, iter_acc, iter_output);

    // Update `is_inbounds` if the offsets changed.
    Value new_is_inbounds = UpdateIsInbounds(
        nested_b, iter_is_inbounds, offsets_changed, new_offsets,
        description_.slice_shape, description_.output_shape);

    // Emits a loop that overwrites the accumulator with the new update elements
    // if the offsets changed.
    auto emit_overwrite_accumulator_fn = [&](OpBuilder& then_b,
                                             Location then_loc) -> void {
      ImplicitLocOpBuilder implicit_then_b(then_loc, then_b);
      auto then_results = EmitXlaLoopOp(
          implicit_then_b, Pack({thread_and_block_ids, iter_slice_id}),
          {iter_acc}, slice_indexing,
          [&](ImplicitLocOpBuilder& update_loop_b,
              ValueRange accumulator_indices, ValueRange slice_indices,
              ValueRange inner_iter_args) -> SmallVector<Value> {
            Value acc_arg = inner_iter_args.front();
            auto update_elem =
                helper.GetUpdateElement(update_loop_b, slice_indices);
            auto acc_ind_opfold = mlir::getAsOpFoldResult(accumulator_indices);
            return vector::InsertOp::create(update_loop_b, then_loc,
                                            update_elem, acc_arg,
                                            acc_ind_opfold)
                ->getResults();
          });
      scf::YieldOp::create(implicit_then_b, then_loc, then_results);
    };
    // Emits a loop that combines the accumulator with the new update elements
    // if the offsets did not change.
    auto emit_combine_accumulator_fn = [&](OpBuilder& else_b,
                                           Location else_loc) -> void {
      ImplicitLocOpBuilder implicit_else_b(else_loc, else_b);
      auto else_results = EmitXlaLoopOp(
          implicit_else_b, Pack({thread_and_block_ids, iter_slice_id}),
          {iter_acc}, slice_indexing,
          [&](ImplicitLocOpBuilder& update_loop_b,
              ValueRange accumulator_indices, ValueRange slice_indices,
              ValueRange inner_iter_args) -> SmallVector<Value> {
            Value acc_arg = inner_iter_args.front();
            auto update_elem =
                helper.GetUpdateElement(update_loop_b, slice_indices);
            auto acc_ind_opfold = mlir::getAsOpFoldResult(accumulator_indices);
            Value accumulator_elem = vector::ExtractOp::create(
                update_loop_b, acc_arg, acc_ind_opfold);
            auto reduced_val = emitters::InlineBlock(
                update_loop_b, helper.GetReducer().getBody().front(),
                {accumulator_elem, update_elem})[0];
            return vector::InsertOp::create(update_loop_b, reduced_val, acc_arg,
                                            acc_ind_opfold)
                ->getResults();
          });
      scf::YieldOp::create(implicit_else_b, else_results);
    };
    auto updated_accumulator =
        EmitUpdateIf(nested_b, new_is_inbounds, {iter_acc},
                     [&](ImplicitLocOpBuilder& if_b) {
                       return scf::IfOp::create(nested_b, offsets_changed,
                                                emit_overwrite_accumulator_fn,
                                                emit_combine_accumulator_fn)
                           .getResults();
                     })
            .front();
    SmallVector<Value> updated_if_loop_results =
        Pack({iter_slice_id, new_trimmed_offsets, new_is_inbounds,
              updated_accumulator, iter_output});
    return updated_if_loop_results;
  };
  auto loop_over_indices_results =
      EmitXlaLoopOp(b, thread_and_block_ids, inits, thread_id_to_update_id_map,
                    loop_over_indices_fn);

  // Write the accumulator to the output tensor.
  SmallVector<ValueRange> loop_over_indices_results_unpacked =
      Unpack(loop_over_indices_results,
             {1, description_.index_vector_length, 1, 1, 1});
  Value result_slice_id = loop_over_indices_results_unpacked[0].front();
  auto result_offsets =
      PadWithZeros(loop_over_indices_results_unpacked[1], output_rank, b);
  Value result_is_inbounds = loop_over_indices_results_unpacked[2].front();
  Value result_acc = loop_over_indices_results_unpacked[3].front();
  Value result_output = loop_over_indices_results_unpacked[4].front();
  result_output = helper.WriteAccumulatorToOutput(
      b, result_is_inbounds, thread_and_block_ids, result_slice_id,
      slice_indexing, result_offsets, result_acc, result_output);

  ReturnOp::create(b, result_output);
  return absl::OkStatus();
}

// Gathers the static properties of the scatter hero of `analysis` into a
// ScatterDescription.
//
// Only the first fusion hero, the first scatter update and the first scatter
// operand are inspected. The result is filled positionally with: the scatter
// instruction, dimensions 0 and 1 of the indices shape, the element type of
// the first operand, the shape of the first update, the update dimensions
// without the leading one, and all dimensions of the first operand.
//
// NOTE(review): this indexes dimensions 0 and 1 of the indices shape and skips
// the first update dimension, so it presumably relies on the scatter having
// been canonicalized beforehand -- confirm against the callers.
ScatterDescription GetScatterDescription(const HloFusionAnalysis& analysis) {
  // `instruction()` is used as an lvalue here, so its address is never null and
  // a separate null check on it would be vacuous. Cast is presumably what
  // rejects a hero that is not a scatter -- confirm.
  const auto* scatter =
      Cast<HloScatterInstruction>(&analysis.fusion_hero(0).instruction());

  // Bind the shapes by reference: they are only read here, and copying a Shape
  // just to query a few dimensions is wasted work. The referenced shapes belong
  // to instructions that outlive this function.
  const auto& indices_shape = scatter->scatter_indices()->shape();
  const auto& update_shape = scatter->scatter_updates().front()->shape();
  const auto& output_shape = scatter->scatter_operands().front()->shape();

  const auto& update_dims = update_shape.dimensions();
  const auto& output_dims = output_shape.dimensions();

  return ScatterDescription{
      scatter,
      indices_shape.dimensions(0),
      indices_shape.dimensions(1),
      output_shape.element_type(),
      update_shape,
      // Drop the leading update dimension; the rest describes a single slice.
      SmallVector<int64_t, 2>(update_dims.begin() + 1, update_dims.end()),
      SmallVector<int64_t, 2>(output_dims.begin(), output_dims.end()),
  };
}

// Returns the vector size to use when one warp processes a single slice.
//
// The starting point is the greatest common divisor of the slice element count
// and `max_vectorized_elements`. The candidate is then halved for as long as it
// is still greater than 1 and a full warp (`warp_size` lanes, each handling one
// vector) would cover more elements than the slice contains.
int64_t GetSingleSliceVectorSize(int64_t num_elements_in_slice,
                                 int64_t max_vectorized_elements,
                                 int64_t warp_size) {
  int64_t candidate = std::gcd(num_elements_in_slice, max_vectorized_elements);
  // `elements_per_warp` is halved in lock-step with `candidate` instead of
  // being recomputed from it.
  for (int64_t elements_per_warp = warp_size * candidate;
       candidate > 1 && elements_per_warp > num_elements_in_slice;
       elements_per_warp /= 2) {
    candidate /= 2;
  }
  return candidate;
}

// Returns the product of `output_shape[d] - slice_shape[d] + 1` over the first
// `index_vector_length` dimensions `d`. Each factor is the number of offsets
// along `d` at which a slice of that extent still fits inside the output.
// Returns 1 when `index_vector_length` is not positive.
int64_t GetNumPossibleValidIndices(absl::Span<const int64_t> slice_shape,
                                   absl::Span<const int64_t> output_shape,
                                   int64_t index_vector_length) {
  int64_t total = 1;
  int64_t dim = 0;
  while (dim < index_vector_length) {
    const int64_t positions_in_dim = output_shape[dim] - slice_shape[dim] + 1;
    total *= positions_in_dim;
    ++dim;
  }
  return total;
}

std::unique_ptr<ScatterFusion> CreateScatterFusion(
    const HloFusionAnalysis& analysis, MLIRContext* mlir_context) {
  auto description = GetScatterDescription(analysis);
  int64_t warp_size = WarpSize(analysis.device_info());
  int64_t num_elements_per_slice = Product(description.slice_shape);
  int64_t num_slices = description.num_slices;

  // Initialize the vector size with the maximum allowed vector size that does
  // not require masking/padding.
  int64_t elem_type_bits = primitive_util::BitWidth(description.elem_type);
  CHECK_EQ(kMaxVectorizedBits % elem_type_bits, 0);
  int64_t max_vectorized_elements = kMaxVectorizedBits / elem_type_bits;
  int64_t vector_size = GetSingleSliceVectorSize(
      num_elements_per_slice, max_vectorized_elements, warp_size);
  int64_t max_active_warps =
      kNumWarpsPerBlock * analysis.device_info().core_count();

  // If indices are sorted and not unique, we can use the distributed indices
  // implementation to accumulate the updates before writing them to the output
  // tensor.
  if (description.scatter->indices_are_sorted() &&
      !description.scatter->unique_indices() && num_slices > max_active_warps) {
    int64_t num_indices_per_warp = 1;
    int64_t indices_vector_size = 1;
    int64_t num_warps_per_slice = 1;
    // We try to estimate the number of updates per warp by computing the ratio
    // of the number of the given updates to the number of the possible valid
    // indices. If we do not have multiple updates per warp, there is no reason
    // to use this algorithm.
    num_indices_per_warp = CeilOfRatio(
        num_slices,
        std::max(max_active_warps,
                 GetNumPossibleValidIndices(description.slice_shape,
                                            description.output_shape,
                                            description.index_vector_length)));

    // If the index_vector_length is 1, we can vectorize the indices read.
    int64_t index_elem_type_bits = primitive_util::BitWidth(
        description.scatter->scatter_indices()->shape().element_type());
    int64_t max_vectorized_indices = kMaxVectorizedBits / index_elem_type_bits;
    if (description.index_vector_length == 1 &&
        num_indices_per_warp > max_vectorized_indices) {
      // Pad num_indices_per_warp to the next multiple of
      // max_vectorized_indices.
      num_indices_per_warp =
          CeilOfRatio(num_indices_per_warp, max_vectorized_indices) *
          max_vectorized_indices;
      indices_vector_size = max_vectorized_indices;
    }
    return std::make_unique<ScatterWithDistributedIndices>(
        analysis, description, vector_size, num_warps_per_slice,
        num_indices_per_warp, indices_vector_size, mlir_context);
  }
  // Otherwise, we distribute the linearized updates tensor.
  vector_size =
      std::gcd(num_elements_per_slice,
               ComputeLoopFusionConfig(analysis, description.update_shape)
                   .unroll_factor);
  return std::make_unique<ScatterWithDistributedUpdates>(
      analysis, description, vector_size, mlir_context);
}

}  // namespace gpu
}  // namespace xla
