/* Copyright 2021 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mhlo/utils/type_conversion.h"

#include <cassert>
#include <cstddef>

#include "llvm/ADT/STLExtras.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir {

class Value;

namespace mhlo {

namespace {

// Returns the signless integer type that has the same bit width as `intType`,
// dropping any signed/unsigned semantics.
Type convertInteger(IntegerType intType) {
  MLIRContext* context = intType.getContext();
  const unsigned bitWidth = intType.getIntOrFloatBitWidth();
  return IntegerType::get(context, bitWidth, IntegerType::Signless);
}

// Converts the element type of `shapedType` with `convertInteger` when it is
// an integer type. Shaped types with any other element type are returned
// unchanged.
Type convertShapedType(ShapedType shapedType) {
  auto intElementType =
      mlir::dyn_cast<IntegerType>(shapedType.getElementType());
  if (!intElementType) return shapedType;
  return shapedType.clone(convertInteger(intElementType));
}

// Materializes a cast from a value whose element type is a signed or unsigned
// integer to `type`, whose element type must be a signless integer.
//
// Returns a null `Value` when the materialization does not apply: there is not
// exactly one input, the input element type is not a signed/unsigned integer,
// or the requested element type is not a signless integer.
Value materializeCastFromIllegal(OpBuilder& builder, Type type,
                                 ValueRange inputs, Location loc) {
  // This cast is strictly one-to-one. Bail out instead of indexing into an
  // empty range or silently dropping additional inputs.
  if (inputs.size() != 1) return Value();
  Value input = inputs.front();
  Type fromType = getElementTypeOrSelf(input.getType());
  Type toType = getElementTypeOrSelf(type);
  if ((!fromType.isSignedInteger() && !fromType.isUnsignedInteger()) ||
      !toType.isSignlessInteger())
    return Value();
  // Use unrealized conversion casts to do signful->signless conversions.
  return UnrealizedConversionCastOp::create(builder, loc, type, input)
      ->getResult(0);
}

// Materializes a cast from a value whose element type is a signless integer
// back to `type`, whose element type must be a signed or unsigned integer.
//
// Returns a null `Value` when the materialization does not apply: there is not
// exactly one input, the input element type is not a signless integer, or the
// requested element type is not a signed/unsigned integer.
Value materializeCastToIllegal(OpBuilder& builder, Type type,
                               ValueRange inputs, Location loc) {
  // This cast is strictly one-to-one. Bail out instead of indexing into an
  // empty range or silently dropping additional inputs.
  if (inputs.size() != 1) return Value();
  Value input = inputs.front();
  Type fromType = getElementTypeOrSelf(input.getType());
  Type toType = getElementTypeOrSelf(type);
  if (!fromType.isSignlessInteger() ||
      (!toType.isSignedInteger() && !toType.isUnsignedInteger()))
    return Value();
  // Use unrealized conversion casts to do signless->signful conversions.
  return UnrealizedConversionCastOp::create(builder, loc, type, input)
      ->getResult(0);
}

// Concatenates the values of all given ranges, in order, into a single vector.
SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
  SmallVector<Value> flattened;
  for (ValueRange range : values) {
    for (Value value : range) flattened.push_back(value);
  }
  return flattened;
}

// Same as the `CallOpSignatureConversion` pattern this one is modelled on
// (presumably the one behind mlir/Dialect/Func/Transforms/FuncConversions.h --
// TODO confirm), except that this one preserves discardable attributes by
// copying the attributes of the original call onto the new call.
struct CallOpSignatureConversion : public OpConversionPattern<func::CallOp> {
  using OpConversionPattern<func::CallOp>::OpConversionPattern;

  /// Replaces `callOp` with a call to the same callee that takes the
  /// flattened converted operands and produces the converted result types.
  /// Fails if any of the original result types cannot be converted.
  LogicalResult matchAndRewrite(
      func::CallOp callOp, OneToNOpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    // Convert the original function results. The growth of `convertedResults`
    // in each iteration is the number of result types that the corresponding
    // original result type was converted into.
    SmallVector<size_t> numResultsReplacements;
    SmallVector<Type, 1> convertedResults;
    for (Type type : callOp.getResultTypes()) {
      const size_t numConvertedBefore = convertedResults.size();
      if (failed(typeConverter->convertTypes(type, convertedResults)))
        return failure();
      numResultsReplacements.push_back(convertedResults.size() -
                                       numConvertedBefore);
    }

    // Substitute with the new result types from the corresponding FuncType
    // conversion.
    auto newCallOp = func::CallOp::create(rewriter, callOp.getLoc(),
                                          callOp.getCallee(), convertedResults,
                                          flattenValues(adaptor.getOperands()));
    // Carry over the attributes of the original call.
    newCallOp->setAttrs(callOp->getAttrs());

    // Hand each original result the slice of new results it was converted
    // into. There is exactly one entry in `numResultsReplacements` per
    // original result.
    SmallVector<ValueRange> replacements;
    size_t offset = 0;
    for (size_t numReplacements : numResultsReplacements) {
      replacements.push_back(
          newCallOp->getResults().slice(offset, numReplacements));
      offset += numReplacements;
    }
    assert(offset == convertedResults.size() &&
           "expected that all converted results are used");
    rewriter.replaceOpWithMultiple(callOp, replacements);
    return success();
  }
};

}  // namespace

// Sets up a type converter that rewrites integer types, and shaped types with
// integer element types, to their signless equivalents.
// NOTE(review): the registration order below is kept as-is; the underlying
// TypeConverter is understood to try later registrations first, which would
// make the identity callback the fallback -- confirm against MLIR docs.
RemoveSignTypeConverter::RemoveSignTypeConverter() {
  // Identity conversion: returns the type unchanged.
  addConversion([](Type type) { return type; });

  // Integer types and shaped types with integer element types become signless.
  addConversion(convertInteger);
  addConversion(convertShapedType);

  // Casts between the signless and the signed/unsigned form of a value are
  // materialized as unrealized conversion casts.
  addSourceMaterialization(materializeCastToIllegal);
  addTargetMaterialization(materializeCastFromIllegal);
}

}  // namespace mhlo

namespace stablehlo {

// Registers the conversions shared by the converters derived from
// `HloTypeConverter`. The callbacks call `isSourceDialect` and
// `convertSourceDialectEncoding` on this object when they are invoked.
HloTypeConverter::HloTypeConverter() {
  // Generic types.
  // We cannot use an allowlist here because HLO dialects can be embedded
  // into programs with other dialects which can involve other types.
  // However, we restrict the use of types defined in the source dialect.
  // This check is here only for exceptional situations, e.g. when we added
  // a new type and forgot to update the converters in the subclass.
  addConversion([this](Type type) -> Type {
    return isSourceDialect(type.getDialect()) ? Type() : type;
  });

  // Ranked tensors.
  // Since this type converter can be used in all sorts of programs,
  // we generally want to allow most of the encodings to pass through,
  // However, we restrict the use of encodings defined in the source dialect.
  addConversion([this](RankedTensorType type) -> Type {
    Attribute encoding = type.getEncoding();
    const bool needsConversion =
        encoding && isSourceDialect(encoding.getDialect());
    if (!needsConversion) return type;

    Attribute targetEncoding = convertSourceDialectEncoding(encoding);
    if (!targetEncoding) return {};
    return RankedTensorType::get(type.getShape(), type.getElementType(),
                                 targetEncoding);
  });

  // Tuples: convert the element types; fail if any of them fails to convert.
  addConversion([this](TupleType type) -> Type {
    SmallVector<Type> elementTypes;
    if (failed(convertTypes(type.getTypes(), elementTypes))) return {};
    return TupleType::get(type.getContext(), elementTypes);
  });

  // Similar to tuple, replace contents with StableHLO/MHLO types.
  addConversion([this](mhlo::AsyncBundleType bundle) -> Type {
    SmallVector<Type> elementTypes;
    if (failed(convertTypes(bundle.getTypes(), elementTypes))) return {};
    return mhlo::AsyncBundleType::get(bundle.getContext(), elementTypes);
  });
}

// Adds the MHLO -> StableHLO specific conversions on top of the ones
// registered by `HloTypeConverter`.
HloToStablehloTypeConverter::HloToStablehloTypeConverter()
    : HloTypeConverter() {
  // An MHLO token becomes a StableHLO token.
  addConversion([](mhlo::TokenType mhloToken) -> Type {
    MLIRContext* context = mhloToken.getContext();
    return stablehlo::TokenType::get(context);
  });
  // Consider implementing stablehlo::CustomType to provide an escape hatch
  // for modelling MHLO types that aren't yet in StableHLO.
  // Proposal: https://github.com/openxla/stablehlo/issues/743.
}

// Returns true if `dialect` has the namespace of the MHLO dialect, which is
// the source dialect of this converter.
bool HloToStablehloTypeConverter::isSourceDialect(Dialect& dialect) {
  const StringRef sourceNamespace = mhlo::MhloDialect::getDialectNamespace();
  return dialect.getNamespace() == sourceNamespace;
}

// Converts an MHLO tensor encoding to its StableHLO counterpart. Returns a
// null attribute when `attr` is null or is not an `mhlo::TypeExtensionsAttr`.
Attribute HloToStablehloTypeConverter::convertSourceDialectEncoding(
    Attribute attr) {
  auto mhloExtensions = mlir::dyn_cast_or_null<mhlo::TypeExtensionsAttr>(attr);
  // Our guiding principle is to support all MHLO encodings in StableHLO.
  // This check is here only for exceptional situations, e.g. when we added
  // a new MHLO encoding and forgot to update the code below.
  if (!mhloExtensions) return {};

  return stablehlo::TypeExtensionsAttr::get(mhloExtensions.getContext(),
                                            mhloExtensions.getBounds());
}

// Default configuration: behaves like `StablehloToHloTypeConverter(true)`,
// i.e. StableHLO tokens are converted to MHLO tokens.
StablehloToHloTypeConverter::StablehloToHloTypeConverter()
    : StablehloToHloTypeConverter(/*convertXlaSupportedStablehlo=*/true) {}

// Registers the token conversion for the StableHLO -> MHLO direction.
// When `convertXlaSupportedStablehlo` is true, StableHLO tokens become MHLO
// tokens; otherwise they are mapped to StableHLO tokens, i.e. kept in
// StableHLO.
StablehloToHloTypeConverter::StablehloToHloTypeConverter(
    bool convertXlaSupportedStablehlo)
    : HloTypeConverter(),
      convert_xla_supported_stablehlo_(convertXlaSupportedStablehlo) {
  // The flag is captured by value, so the choice is fixed at construction.
  const bool convertToMhlo = convert_xla_supported_stablehlo_;
  addConversion([convertToMhlo](stablehlo::TokenType stablehloType) -> Type {
    MLIRContext* context = stablehloType.getContext();
    if (convertToMhlo) return mhlo::TokenType::get(context);
    return stablehlo::TokenType::get(context);
  });
}

// Returns true if `dialect` has the namespace of the StableHLO dialect, which
// is the source dialect of this converter.
bool StablehloToHloTypeConverter::isSourceDialect(Dialect& dialect) {
  const StringRef sourceNamespace =
      stablehlo::StablehloDialect::getDialectNamespace();
  return dialect.getNamespace() == sourceNamespace;
}

// Converts a StableHLO tensor encoding to its MHLO counterpart. Returns a null
// attribute when `attr` is null or is not a `stablehlo::TypeExtensionsAttr`.
Attribute StablehloToHloTypeConverter::convertSourceDialectEncoding(
    Attribute attr) {
  auto stablehloExtensions =
      mlir::dyn_cast_or_null<stablehlo::TypeExtensionsAttr>(attr);
  // Our guiding principle is to support all StableHLO encodings in MHLO.
  // This check is here only for exceptional situations, e.g. when we added
  // a new StableHLO encoding and forgot to update the code below.
  if (!stablehloExtensions) return {};

  return mhlo::TypeExtensionsAttr::get(stablehloExtensions.getContext(),
                                       stablehloExtensions.getBounds());
}

// Configures `target` and `patterns` so that `func.func`, `func.call` and
// `func.return` are type-converted with `converter`.
//
// The legality callbacks registered on `target` hold a reference to
// `converter`, so `converter` must stay alive for as long as `target` is used.
void registerFuncOpsForTypeConversion(ConversionTarget& target,
                                      RewritePatternSet& patterns,
                                      TypeConverter& converter) {
  // A function is legal once its signature is legal for the converter.
  target.addDynamicallyLegalOp<func::FuncOp>([&converter](func::FuncOp func) {
    auto signature = func.getFunctionType();
    return converter.isSignatureLegal(signature);
  });
  // A call is legal once the signature it calls with is legal.
  target.addDynamicallyLegalOp<func::CallOp>([&converter](func::CallOp call) {
    auto signature = call.getCalleeType();
    return converter.isSignatureLegal(signature);
  });
  // A return is legal once all of its operand types are legal.
  target.addDynamicallyLegalOp<func::ReturnOp>(
      [&converter](func::ReturnOp ret) {
        return converter.isLegal(ret.getOperandTypes());
      });

  MLIRContext* context = patterns.getContext();
  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                 converter);
  patterns.add<mhlo::CallOpSignatureConversion>(converter, context);
  populateReturnOpTypeConversionPattern(patterns, converter);
}

}  // namespace stablehlo

}  // namespace mlir
