/* Copyright 2018 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/rocm/rocm_platform.h"

#include <memory>
#include <string>
#include <utility>

#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "rocm/include/hip/hip_runtime.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform/initialize.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/stream_executor/rocm/rocm_driver_wrapper.h"
#include "xla/stream_executor/rocm/rocm_executor.h"
#include "xla/stream_executor/rocm/rocm_platform_id.h"
#include "xla/stream_executor/rocm/rocm_status.h"
#include "xla/tsl/platform/errors.h"

namespace stream_executor {
namespace gpu {
namespace {

// Calls hipInit and converts its result into an absl::Status. Intended to be
// reached through PlatformInitialize(), which runs it a single time.
//
// Returns OkStatus on hipSuccess; otherwise logs the failure and returns an
// Aborted error carrying the same message.
static absl::Status InternalInitialize() {
  const hipError_t init_result = wrap::hipInit(/*flags=*/0);
  if (init_result != hipSuccess) {
    // Build the text once so the log line and the returned status agree.
    const std::string message =
        absl::StrCat("failed call to hipInit: ", ToString(init_result));
    LOG(ERROR) << message;
    return absl::AbortedError(message);
  }
  return absl::OkStatus();
}

// Returns the outcome of the one-time HIP initialization. The first call runs
// InternalInitialize(); every later call returns a copy of that stored result,
// so this may be invoked as often as needed.
static absl::Status PlatformInitialize() {
  // Function-local static: the initializer (and therefore hipInit) executes
  // once. The Status lives on the heap and is never deleted.
  static const absl::Status* const cached_status =
      new absl::Status(InternalInitialize());
  return *cached_status;
}
}  // namespace

// Constructs the platform object; its reported name is fixed to "ROCM".
ROCmPlatform::ROCmPlatform()
    : name_("ROCM") {
}

// Returns the platform identifier constant for ROCm.
Platform::Id ROCmPlatform::id() const {
  return rocm::kROCmPlatformId;
}

int ROCmPlatform::VisibleDeviceCount() const {
  // Throw away the result - it logs internally, and this [containing] function
  // isn't in the path of user control. It's safe to call this > 1x.

  if (!PlatformInitialize().ok()) {
    return -1;
  }

  int device_count = 0;
  hipError_t res = wrap::hipGetDeviceCount(&device_count);
  if (res != hipSuccess) {
    LOG(ERROR) << "could not retrieve ROCM device count: " << ToString(res);
    return 0;
  }

  return device_count;
}

// Returns a reference to the platform name stored at construction.
const std::string& ROCmPlatform::Name() const {
  return name_;
}

// Builds a DeviceDescription for the device with the given ordinal.
//
// Fails with the initialization error if PlatformInitialize() did not succeed;
// otherwise forwards whatever RocmExecutor::CreateDeviceDescription returns.
absl::StatusOr<std::unique_ptr<DeviceDescription>>
ROCmPlatform::DescriptionForDevice(int ordinal) const {
  TF_RETURN_IF_ERROR(PlatformInitialize());

  absl::StatusOr<std::unique_ptr<DeviceDescription>> description =
      RocmExecutor::CreateDeviceDescription(ordinal);
  return description;
}

// Returns the StreamExecutor for `ordinal` via executor_cache_.
//
// Fails with the initialization error if PlatformInitialize() did not succeed.
absl::StatusOr<StreamExecutor*> ROCmPlatform::ExecutorForDevice(int ordinal) {
  TF_RETURN_IF_ERROR(PlatformInitialize());

  // Factory handed to the cache; it builds a fresh executor through
  // GetUncachedExecutor (presumably invoked only on a cache miss).
  auto make_executor = [this, ordinal]() {
    return GetUncachedExecutor(ordinal);
  };
  return executor_cache_.GetOrCreate(ordinal, make_executor);
}

// Looks up `ordinal` in executor_cache_ and returns the cache's answer as-is.
// Unlike ExecutorForDevice, this does not call PlatformInitialize() and passes
// no factory to the cache.
absl::StatusOr<StreamExecutor*> ROCmPlatform::FindExisting(int ordinal) {
  absl::StatusOr<StreamExecutor*> existing = executor_cache_.Get(ordinal);
  return existing;
}

// Creates and initializes a new RocmExecutor for `ordinal`, bypassing
// executor_cache_. Returns the executor's Init() error if initialization
// fails; otherwise transfers ownership of the executor to the caller.
absl::StatusOr<std::unique_ptr<StreamExecutor>>
ROCmPlatform::GetUncachedExecutor(int ordinal) {
  std::unique_ptr<RocmExecutor> rocm_executor =
      std::make_unique<RocmExecutor>(this, ordinal);
  TF_RETURN_IF_ERROR(rocm_executor->Init());
  // The local's type differs from the return type, so the move is spelled out
  // explicitly rather than relying on an implicit move on return.
  return std::move(rocm_executor);
}

}  // namespace gpu

// Registers a ROCmPlatform with the PlatformManager, unless looking up a
// platform named "ROCM" already succeeds. A failed registration is fatal
// (CHECK_OK).
static void InitializeROCmPlatform() {
  const bool already_registered =
      PlatformManager::PlatformWithName("ROCM").ok();
  if (already_registered) {
    return;
  }

  CHECK_OK(PlatformManager::RegisterPlatform(
      std::make_unique<gpu::ROCmPlatform>()));
}

}  // namespace stream_executor

// Registers a module initializer named `rocm_platform` whose body calls
// stream_executor::InitializeROCmPlatform().
// NOTE(review): presumably the macro arranges for this to run during static
// initialization at program load -- confirm against
// xla/stream_executor/platform/initialize.h.
STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(
    rocm_platform, stream_executor::InitializeROCmPlatform());
