/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/autotune_maps/conv_parameters.h"

#include <cstddef>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/strings/str_replace.h"
#include "xla/tsl/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/hash.h"
#include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h"
#include "tsl/platform/regexp.h"

namespace tensorflow {

namespace {
using ::tsl::protobuf::util::MessageDifferencer;

// Returns the hash used for convolution autotune keys: the result of
// tsl::DeterministicProtoHash64 on `proto`, mixed with `device_id` through
// Hash64Combine.
uint64_t ComputeHash(int device_id, const ConvParametersProto& proto) {
  const auto proto_hash = tsl::DeterministicProtoHash64(proto);
  return Hash64Combine(device_id, proto_hash);
}

// Returns the hash used for matmul autotune keys: the result of
// tsl::DeterministicProtoHash64 on `proto`, mixed with `device_id` through
// Hash64Combine.
uint64_t ComputeHash(int device_id, const MatmulParametersProto& proto) {
  const auto proto_hash = tsl::DeterministicProtoHash64(proto);
  return Hash64Combine(device_id, proto_hash);
}
}  // namespace

// Coarsens the RAM size embedded in a device description string.
//
// When `device_identifier` fully matches the pattern
//   "sm_<non-space> with <digits>B RAM, <digits> cores, <digits>KHz clock,
//    <digits>KHz mem clock, <digits>B L2$"
// the "<digits>B RAM" portion is rewritten as "<x.y>GB RAM", where the value
// is the byte count divided by 1e9 and printed with one decimal place.
// Strings that do not match are returned unchanged (as a copy).
std::string DeviceIdentifierForAutotuning(absl::string_view device_identifier) {
  static const LazyRE2 kDevicePattern = {
      R"regexp(^sm_[^ ]+ with ([\d]+)B RAM, [\d]+ cores, [\d]+KHz clock, [\d]+KHz mem clock, [\d]+B L2\$$)regexp"};

  // Filled in by RE2 from the single capture group on a successful match.
  uint64_t ram_bytes;
  if (RE2::FullMatch(device_identifier, *kDevicePattern, &ram_bytes)) {
    // Text of the exact byte count as re-rendered from the parsed value.
    const std::string exact_ram = absl::StrCat(ram_bytes, "B RAM");
    // "%.1f" rounds to the nearest tenth of a decimal gigabyte (1e9 bytes);
    // it does not round up.
    const std::string coarse_ram =
        absl::StrFormat("%.1fGB RAM", static_cast<double>(ram_bytes) / 1e9);
    return absl::StrReplaceAll(device_identifier, {{exact_ram, coarse_ram}});
  }

  // Unrecognized layout: hand the identifier back untouched.
  return std::string(device_identifier);
}

// Builds a convolution autotune key from explicit shape/config arguments.
//
// The device ordinal of `stream_exec` is stored in device_id_; every other
// argument is copied into proto_. The proto's device_identifier is the model
// string of the executor's device description, passed through
// DeviceIdentifierForAutotuning. The fusion sub-message is populated only
// when `fusion_info` holds a value. hash_code_ is computed last, once proto_
// is complete.
ConvParameters::ConvParameters(
    se::StreamExecutor* stream_exec, int64_t batch, int64_t in_depths,
    const absl::Span<const int64_t> in, int data_format, int64_t out_depths,
    const absl::Span<const int64_t> filter,
    const absl::Span<const int64_t> dilation,
    const absl::Span<const int64_t> stride,
    const absl::Span<const int64_t> padding, DataType dtype, int group_count,
    absl::optional<ConvParameters::FusionInfo> fusion_info, int version)
    : device_id_(stream_exec->device_ordinal()) {
  // Scalar fields.
  proto_.set_batch(batch);
  proto_.set_in_depths(in_depths);
  proto_.set_out_depths(out_depths);
  proto_.set_data_format(data_format);
  proto_.set_dtype(dtype);
  proto_.set_group_count(group_count);
  proto_.set_version(version);

  // Repeated fields, each copied element-wise from its span.
  *proto_.mutable_in() = {in.begin(), in.end()};
  *proto_.mutable_filter() = {filter.begin(), filter.end()};
  *proto_.mutable_dilation() = {dilation.begin(), dilation.end()};
  *proto_.mutable_stride() = {stride.begin(), stride.end()};
  *proto_.mutable_padding() = {padding.begin(), padding.end()};

  // Optional fusion description; the field is left unset when absent.
  if (fusion_info) {
    const ConvParameters::FusionInfo& info = *fusion_info;
    ConvParametersProto::Fusion* fusion = proto_.mutable_fusion();
    fusion->set_conv_scale(info.conv_scale);
    fusion->set_side_input_scale(info.side_input_scale);
    fusion->set_activation_mode(info.activation_mode);
    fusion->set_is_contrib(info.is_contrib);
  }

  // Device description string, with its RAM size coarsened.
  proto_.set_device_identifier(DeviceIdentifierForAutotuning(
      stream_exec->GetDeviceDescription().model_str()));

  // Must run after every proto_ field above has been written.
  hash_code_ = ComputeHash(device_id_, proto_);
}
// Wraps an already-populated `proto` for the device with ordinal `device_id`.
// The proto is copied as-is and the hash is computed from the stored copy.
ConvParameters::ConvParameters(int device_id, const ConvParametersProto& proto)
    : device_id_(device_id), proto_(proto) {
  hash_code_ = ComputeHash(device_id_, proto_);
}

// Two keys are equal when they share a device id and their protos compare
// equal under MessageDifferencer::Equals.
bool ConvParameters::operator==(const ConvParameters& other) const {
  if (device_id_ != other.device_id_) {
    return false;
  }
  return MessageDifferencer::Equals(proto_, other.proto_);
}

// Returns the debug-string rendering of the underlying proto.
std::string ConvParameters::ToString() const {
  return proto_.DebugString();
}

// Builds a matmul autotune key from explicit arguments.
//
// The device ordinal of `stream_exec` is stored in device_id_; the remaining
// arguments are copied into proto_. The proto's device_identifier is the
// model string of the executor's device description, passed through
// DeviceIdentifierForAutotuning. hash_code_ is computed last, once proto_ is
// complete.
MatmulParameters::MatmulParameters(
    se::StreamExecutor* stream_exec, DataType ab_dtype, DataType c_dtype,
    bool trans_a, bool trans_b, uint64_t m, uint64_t n, uint64_t k, int64_t lda,
    int64_t ldb, int64_t ldc,
    stream_executor::dnn::ActivationMode activation_mode, int version)
    : device_id_(stream_exec->device_ordinal()) {
  // Element types.
  proto_.set_ab_dtype(ab_dtype);
  proto_.set_c_dtype(c_dtype);

  // Transposition flags and problem dimensions.
  proto_.set_trans_a(trans_a);
  proto_.set_trans_b(trans_b);
  proto_.set_m(m);
  proto_.set_n(n);
  proto_.set_k(k);

  // Leading dimensions.
  proto_.set_lda(lda);
  proto_.set_ldb(ldb);
  proto_.set_ldc(ldc);

  proto_.set_activation_mode(activation_mode);
  proto_.set_version(version);

  // Device description string, with its RAM size coarsened.
  proto_.set_device_identifier(DeviceIdentifierForAutotuning(
      stream_exec->GetDeviceDescription().model_str()));

  // Must run after every proto_ field above has been written.
  hash_code_ = ComputeHash(device_id_, proto_);
}

// Wraps an already-populated `proto` for the device that `stream_exec` runs
// on. The proto is copied as-is and the hash is computed from the stored copy
// together with the executor's device ordinal.
MatmulParameters::MatmulParameters(se::StreamExecutor* stream_exec,
                                   const MatmulParametersProto& proto)
    : device_id_(stream_exec->device_ordinal()), proto_(proto) {
  hash_code_ = ComputeHash(device_id_, proto_);
}

// Two keys are equal when they share a device id and their protos compare
// equal under MessageDifferencer::Equals.
bool MatmulParameters::operator==(const MatmulParameters& other) const {
  const bool same_device = (device_id_ == other.device_id_);
  // The proto comparison is skipped when the devices already differ.
  return same_device && MessageDifferencer::Equals(proto_, other.proto_);
}

// Returns the debug-string rendering of the underlying proto.
std::string MatmulParameters::ToString() const {
  return proto_.DebugString();
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
