/* Copyright 2023 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_PYTHON_IFRT_MOCK_H_
#define XLA_PYTHON_IFRT_MOCK_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/no_destructor.h"
#include "absl/container/flat_hash_map.h"
#include "absl/hash/hash.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/testlib/test.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/pjrt_layout.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/array_spec.h"
#include "xla/python/ifrt/attribute_map.h"
#include "xla/python/ifrt/basic_device_list.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/compiler.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/python/ifrt/dtype.h"
#include "xla/python/ifrt/executable.h"
#include "xla/python/ifrt/executable_serdes.h"
#include "xla/python/ifrt/host_callback.h"
#include "xla/python/ifrt/index_domain.h"
#include "xla/python/ifrt/layout.h"
#include "xla/python/ifrt/memory.h"
#include "xla/python/ifrt/mpmd_executable.h"
#include "xla/python/ifrt/program.h"
#include "xla/python/ifrt/remap_plan.h"
#include "xla/python/ifrt/shape.h"
#include "xla/python/ifrt/sharding.h"
#include "xla/python/ifrt/topology.h"
#include "xla/python/ifrt/tuple.h"
#include "xla/python/ifrt/user_context.h"
#include "xla/python/ifrt/value.h"
#include "xla/tsl/concurrency/future.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/xla_data.pb.h"

namespace xla {
namespace ifrt {

// array.h

// gMock-based test double for `Array`.
//
// All mocked members are declared `final`. A `MockArray` can optionally carry
// another array, exposed through `delegated()`.
class MockArray : public llvm::RTTIExtends<MockArray, Array> {
 public:
  // Creates a mock without a delegate; `delegated()` then returns a
  // default-constructed `ArrayRef`.
  MockArray() = default;
  // Creates a mock associated with `delegated`. Only declared here.
  // NOTE(review): presumably installs default actions that forward the mocked
  // calls to `delegated` (see the LINT directive below) — confirm in mock.cc.
  explicit MockArray(xla::ifrt::ArrayRef delegated);

  // LINT.IfChange
  MOCK_METHOD(Client*, client, (), (const, final));
  MOCK_METHOD(tsl::Future<>, GetReadyFuture, (), (const, final));
  MOCK_METHOD(tsl::Future<>, Delete, (), (final));
  MOCK_METHOD(bool, IsDeleted, (), (const, final));

  MOCK_METHOD(DType, dtype, (), (const, final));
  MOCK_METHOD(const Shape&, shape, (), (const, final));
  MOCK_METHOD(const Sharding&, sharding, (), (const, final));
  MOCK_METHOD(ShardingRef, shared_ptr_sharding, (), (const, final));
  MOCK_METHOD(absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>>,
              pjrt_layout, (), (const, final));
  MOCK_METHOD(CustomLayoutRef, layout, (), (const, final));
  MOCK_METHOD(UserContextRef, user_context, (), (const, final));
  MOCK_METHOD(absl::StatusOr<std::vector<ArrayRef>>,
              DisassembleIntoSingleDeviceArrays,
              (ArrayCopySemantics array_copy_semantics,
               SingleDeviceShardSemantics single_device_shard_semantics),
              (final));
  MOCK_METHOD(absl::StatusOr<ArrayRef>, FullyReplicatedShard,
              (ArrayCopySemantics semantics), (final));
  MOCK_METHOD(tsl::Future<>, CopyToHostBuffer,
              (void* data,
               std::optional<absl::Span<const int64_t>> byte_strides,
               ArrayCopySemantics semantics),
              (final));
  // LINT.ThenChange(mock.cc:MockArrayDelegation)

  // Returns a copy of the array reference passed at construction (or a
  // default-constructed one when the default constructor was used).
  xla::ifrt::ArrayRef delegated() const { return delegated_; }

  // Not mocked: always returns the fixed string "MockArray".
  std::string DebugString() const final { return "MockArray"; }

  // Type-identity anchor used by `llvm::RTTIExtends`; declared here and
  // defined elsewhere.
  static char ID;  // NOLINT

 private:
  // Fixed at construction; never reassigned.
  const xla::ifrt::ArrayRef delegated_;
};

// client.h

// gMock-based test double for `Client`.
//
// A `MockClient` can optionally own another client, exposed through
// `delegated()`. Mocked members are `final`, except `MakeDeviceList`, which is
// marked `override` so that it stays overridable by subclasses while still
// being checked against the `Client` declaration at compile time.
class MockClient : public llvm::RTTIExtends<MockClient, Client> {
 public:
  // Creates a mock without a delegate; `delegated()` then returns nullptr.
  MockClient() = default;
  // Creates a mock that takes ownership of `delegated`. Only declared here.
  // NOTE(review): presumably installs default actions that forward the mocked
  // calls to `delegated` (see the LINT directive below) — confirm in mock.cc.
  explicit MockClient(std::unique_ptr<xla::ifrt::Client> delegated);

  // LINT.IfChange
  MOCK_METHOD(absl::StatusOr<ArrayRef>, MakeArrayFromHostBuffer,
              (const void* data, DType dtype, Shape shape,
               std::optional<absl::Span<const int64_t>> byte_strides,
               ShardingRef sharding, HostBufferSemantics semantics,
               std::function<void()> on_done_with_host_buffer),
              (final));
  MOCK_METHOD(absl::StatusOr<std::vector<ArrayRef>>,
              MakeArraysFromHostBufferShards,
              (absl::Span<MakeArraysFromHostBufferShardsSpec> specs,
               HostBufferSemantics semantics),
              (final));
  MOCK_METHOD(absl::StatusOr<std::vector<ArrayRef>>, MakeErrorArrays,
              (const absl::Status& error,
               absl::Span<const ArraySpec> array_specs),
              (final));
  MOCK_METHOD(absl::StatusOr<ArrayRef>, AssembleArrayFromSingleDeviceArrays,
              (DType dtype, Shape shape, ShardingRef sharding,
               absl::Span<ArrayRef> arrays,
               ArrayCopySemantics array_copy_semantics,
               SingleDeviceShardSemantics single_device_shard_semantics),
              (final));
  MOCK_METHOD(absl::StatusOr<std::vector<ArrayRef>>, CopyArrays,
              (absl::Span<ArrayRef> arrays,
               std::optional<DeviceListRef> devices,
               std::optional<MemoryKind> memory_kind,
               ArrayCopySemantics semantics),
              (final));
  MOCK_METHOD(absl::StatusOr<std::vector<ArrayRef>>, RemapArrays,
              (const RemapPlan& plan, absl::Span<ArrayRef> arrays,
               ArrayCopySemantics semantics),
              (final));
  MOCK_METHOD(absl::StatusOr<std::vector<ArrayRef>>, ReshardArrays,
              (absl::Span<ArrayRef> arrays, absl::Span<const ArraySpec> specs,
               ArrayCopySemantics semantics),
              (final));
  MOCK_METHOD(tsl::Future<>, GetReadyFuture,
              (absl::Span<const ValueRef> values), (final));
  MOCK_METHOD(absl::StatusOr<tsl::RCReference<Tuple>>, MakeTuple,
              (absl::Span<ValueRef> values), (final));
  MOCK_METHOD(
      void, CancelExecution,
      (xla::ifrt::LoadedExecutable::CancellationHandle cancellation_handle,
       absl::Status error),
      (final));
  MOCK_METHOD(absl::string_view, runtime_type, (), (const, final));
  MOCK_METHOD(absl::string_view, platform_name, (), (const, final));
  MOCK_METHOD(absl::string_view, platform_version, (), (const, final));
  MOCK_METHOD((const AttributeMap&), Attributes, (), (const, final));
  MOCK_METHOD(int, device_count, (), (const, final));
  MOCK_METHOD(PlatformId, platform_id, (), (const, final));
  MOCK_METHOD(int, addressable_device_count, (), (const, final));
  MOCK_METHOD(absl::Span<Device* const>, devices, (), (const, final));
  MOCK_METHOD(absl::Span<Device* const>, addressable_devices, (),
              (const, final));
  MOCK_METHOD(int, process_index, (), (const, final));
  MOCK_METHOD(absl::Span<xla::ifrt::Device* const>, GetAllDevices, (),
              (const, final));
  MOCK_METHOD(absl::StatusOr<DeviceAssignment>, GetDefaultDeviceAssignment,
              (int num_replicas, int num_partitions), (const, final));
  MOCK_METHOD(absl::StatusOr<Device*>, LookupDevice, (DeviceId device_id),
              (const, final));
  MOCK_METHOD(absl::StatusOr<Device*>, LookupAddressableDevice,
              (int local_hardware_id), (const, final));
  MOCK_METHOD(absl::StatusOr<DeviceListRef>, MakeDeviceList,
              (absl::Span<Device* const> devices), (const, override));
  MOCK_METHOD(Compiler*, GetDefaultCompiler, (), (final));
  MOCK_METHOD(absl::StatusOr<std::shared_ptr<Topology>>, GetTopologyForDevices,
              (const xla::ifrt::DeviceListRef& devices), (const, final));
  MOCK_METHOD(absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>>,
              GetDefaultPjRtLayout,
              (xla::ifrt::DType dtype, absl::Span<const int64_t> dims,
               xla::ifrt::Device* device, xla::ifrt::MemoryKind memory_kind),
              (const, final));
  MOCK_METHOD(absl::StatusOr<CustomLayoutRef>, GetDefaultLayout,
              (DType dtype, const Shape& shape, const ShardingRef& sharding),
              (const, final));
  MOCK_METHOD(
      absl::StatusOr<std::unique_ptr<xla::ifrt::DeviceAttributeSubscription>>,
      SubscribeToAttributeChanges,
      (absl::Span<xla::ifrt::Device* const> devices,
       std::optional<absl::Span<const std::string>> attribute_names,
       xla::ifrt::OnDeviceAttributeChangeCallback callback),
      (final));
  // LINT.ThenChange(mock.cc:MockClientDelegation)

  // Returns a non-owning pointer to the client passed at construction, or
  // nullptr when the default constructor was used.
  xla::ifrt::Client* delegated() const { return delegated_.get(); }

  // Type-identity anchor used by `llvm::RTTIExtends`; declared here and
  // defined elsewhere.
  static char ID;  // NOLINT

 private:
  // Owned delegate; fixed at construction and released with this object.
  const std::unique_ptr<xla::ifrt::Client> delegated_;
};

// compiler.h

// gMock-based test double for `Compiler`. All mocked members are `final`.
class MockCompiler : public llvm::RTTIExtends<MockCompiler, Compiler> {
 public:
  MOCK_METHOD(absl::StatusOr<ExecutableRef>, Compile,
              (std::unique_ptr<Program> program, const Topology& topology,
               std::unique_ptr<CompileOptions> options),
              (final));
  MOCK_METHOD(absl::StatusOr<LoadedExecutableRef>, CompileAndLoad,
              (std::unique_ptr<Program> program,
               std::unique_ptr<CompileOptions> options),
              (final));
  MOCK_METHOD(absl::Status, IsExecutableVersionCompatible,
              (const xla::ifrt::ExecutableVersion& executable_version,
               const xla::ifrt::DeviceListRef& devices),
              (const, final));
  MOCK_METHOD(absl::StatusOr<LoadedExecutableRef>, DeserializeLoadedExecutable,
              (absl::string_view serialized,
               std::unique_ptr<DeserializeExecutableOptions> options),
              (final));

  // Type-identity anchor used by `llvm::RTTIExtends`; declared here and
  // defined elsewhere.
  static char ID;  // NOLINT
};

// device.h

// gMock-based test double for `Device`.
//
// All mocked members are `final`. A `MockDevice` can optionally remember
// another device, exposed through `delegated()`.
class MockDevice : public Device {
 public:
  // Creates a mock without a delegate; `delegated()` then returns nullptr.
  MockDevice() = default;
  // Creates a mock associated with `delegated`. Only declared here.
  // NOTE(review): presumably installs default actions that forward the mocked
  // calls to `delegated` (see the LINT directive below) — confirm in mock.cc.
  explicit MockDevice(Device* delegated);

  // LINT.IfChange
  MOCK_METHOD(Client*, client, (), (const, final));
  MOCK_METHOD(bool, IsAddressable, (), (const, final));
  MOCK_METHOD(int, ProcessIndex, (), (const, final));
  MOCK_METHOD(DeviceId, Id, (), (const, final));
  MOCK_METHOD(absl::string_view, PlatformName, (), (const, final));
  MOCK_METHOD(absl::string_view, Kind, (), (const, final));
  MOCK_METHOD((const AttributeMap&), Attributes, (), (const, final));
  MOCK_METHOD(absl::StatusOr<Memory*>, DefaultMemory, (), (const, final));
  MOCK_METHOD(absl::Span<Memory* const>, Memories, (), (const, final));
  // LINT.ThenChange(mock.cc:MockDeviceDelegation)

  // Returns the pointer stored at construction (nullptr by default).
  Device* delegated() const { return delegated_; }

  // Not mocked: both always return the fixed string "MockDevice".
  absl::string_view DebugString() const final { return "MockDevice"; }
  absl::string_view ToString() const final { return "MockDevice"; }

 private:
  // Raw pointer fixed at construction. This class declares no destructor, so
  // the pointee is never deleted here.
  Device* const delegated_ = nullptr;
};

// device_list.h

class MockDeviceList : public DeviceList {
 public:
  MockDeviceList() = default;
  ~MockDeviceList() override = default;

  MOCK_METHOD(absl::Span<Device* const>, devices, (), (const final));
  MOCK_METHOD(DeviceList*, AddressableDeviceList, (), (const final));

  MOCK_METHOD(bool, EqualEqualOperator, (const DeviceList& other),
              (const final));
  bool operator==(const DeviceList& other) const override {
    return EqualEqualOperator(other);
  }
  MOCK_METHOD(uint64_t, hash, (), (const final));
  MOCK_METHOD(uint64_t, fingerprint, (), (const final));
  MOCK_METHOD(std::string, ToString, (), (const final));
};

// memory.h

// gMock-based test double for `Memory`. All mocked members are `final`.
class MockMemory : public Memory {
 public:
  MOCK_METHOD(MemoryId, Id, (), (const, final));
  MOCK_METHOD(absl::Span<Device* const>, Devices, (), (const, final));
  MOCK_METHOD(const MemoryKind&, Kind, (), (const, final));
  MOCK_METHOD(absl::string_view, ToString, (), (const, final));

  // Not mocked: always returns the fixed string "MockMemory".
  absl::string_view DebugString() const final { return "MockMemory"; }
};

// executable.h

// gMock-based test double for `Executable`. All mocked members are `final`.
class MockExecutable : public llvm::RTTIExtends<MockExecutable, Executable> {
 public:
  MOCK_METHOD(absl::string_view, name, (), (const, final));
  MOCK_METHOD(absl::StatusOr<std::optional<std::string>>, Fingerprint, (),
              (const, final));
  MOCK_METHOD(absl::StatusOr<std::string>, Serialize, (), (const, final));
  MOCK_METHOD(int, num_devices, (), (const, final));
  MOCK_METHOD(int64_t, SizeOfGeneratedCodeInBytes, (), (const, final));
  MOCK_METHOD(absl::StatusOr<CompiledMemoryStats>, GetCompiledMemoryStats, (),
              (const, final));
  MOCK_METHOD(std::optional<std::vector<OpSharding>>, GetParameterShardings, (),
              (const, final));
  MOCK_METHOD(std::optional<std::vector<OpSharding>>, GetOutputShardings, (),
              (const, final));
  MOCK_METHOD(
      absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>,
      GetParameterLayouts, (), (const, final));
  MOCK_METHOD(
      absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>,
      GetOutputLayouts, (), (const, final));
  MOCK_METHOD(absl::StatusOr<std::vector<std::shared_ptr<HloModule>>>,
              GetHloModules, (), (const, final));
  MOCK_METHOD(absl::StatusOr<xla::ifrt::AttributeMap>, GetCostAnalysis, (),
              (const, final));

  // Type-identity anchor used by `llvm::RTTIExtends`; declared here and
  // defined elsewhere.
  static char ID;  // NOLINT
};

// gMock-based test double for `LoadedExecutable`.
//
// Mocked members are `final`, except `GetReadyFuture`, which is declared
// `override` and therefore remains overridable in subclasses.
class MockLoadedExecutable
    : public llvm::RTTIExtends<MockLoadedExecutable, LoadedExecutable> {
 public:
  MOCK_METHOD(Client*, client, (), (const, final));
  MOCK_METHOD(absl::string_view, name, (), (const, final));
  MOCK_METHOD(absl::StatusOr<std::optional<std::string>>, Fingerprint, (),
              (const, final));
  MOCK_METHOD(absl::StatusOr<std::unique_ptr<xla::ifrt::ExecutableVersion>>,
              executable_version, (), (const, final));
  MOCK_METHOD(absl::StatusOr<std::string>, Serialize, (), (const, final));
  MOCK_METHOD(absl::StatusOr<std::string>, GetHumanReadableProgramText, (),
              (const, final));
  MOCK_METHOD(UserContextRef, user_context, (), (const, final));
  MOCK_METHOD(tsl::Future<>, GetReadyFuture, (), (const, override));
  MOCK_METHOD(int, num_devices, (), (const, final));
  MOCK_METHOD(int64_t, SizeOfGeneratedCodeInBytes, (), (const, final));
  MOCK_METHOD(absl::StatusOr<CompiledMemoryStats>, GetCompiledMemoryStats, (),
              (const, final));
  MOCK_METHOD(std::optional<std::vector<OpSharding>>, GetParameterShardings, (),
              (const, final));
  MOCK_METHOD(std::optional<std::vector<OpSharding>>, GetOutputShardings, (),
              (const, final));
  MOCK_METHOD(
      absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>,
      GetParameterLayouts, (), (const, final));
  MOCK_METHOD(absl::StatusOr<absl::Span<const int>>, GetDonatableInputIndices,
              (), (const, final));
  MOCK_METHOD(
      absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>,
      GetOutputLayouts, (), (const, final));
  MOCK_METHOD(absl::StatusOr<std::vector<std::vector<absl::string_view>>>,
              GetOutputMemoryKinds, (), (const, final));
  MOCK_METHOD(absl::StatusOr<std::vector<std::shared_ptr<HloModule>>>,
              GetHloModules, (), (const, final));
  MOCK_METHOD(absl::StatusOr<xla::ifrt::AttributeMap>, GetCostAnalysis, (),
              (const, final));
  MOCK_METHOD(absl::StatusOr<ExecuteResult>, Execute,
              (absl::Span<ArrayRef> args, const ExecuteOptions& options,
               std::optional<DeviceListRef> devices),
              (final));
  MOCK_METHOD(absl::Span<Device* const>, addressable_devices, (),
              (const, final));
  MOCK_METHOD(std::optional<DeviceListRef>, devices, (), (const, final));

  // Type-identity anchor used by `llvm::RTTIExtends`; declared here and
  // defined elsewhere.
  static char ID;  // NOLINT
};

// gMock-based test double for `MpmdLoadedExecutable`.
//
// Mocked members are `final`, except `GetReadyFuture`, which is declared
// `override` and therefore remains overridable in subclasses.
class MockMpmdLoadedExecutable
    : public llvm::RTTIExtends<MockMpmdLoadedExecutable, MpmdLoadedExecutable> {
 public:
  // Installs a default action for `devices()`: unless a test specifies
  // otherwise, it returns a device list built from an empty set of devices.
  // That list is created once (function-local static) and, being wrapped in
  // `absl::NoDestructor`, is never destroyed.
  MockMpmdLoadedExecutable() {
    static absl::NoDestructor<DeviceListRef> kEmptyDeviceList(
        BasicDeviceList::Create({}));
    ON_CALL(*this, devices()).WillByDefault(testing::Return(*kEmptyDeviceList));
  }

  // Return types containing top-level commas are wrapped in an extra pair of
  // parentheses so that MOCK_METHOD treats them as a single macro argument.
  MOCK_METHOD((absl::StatusOr<absl::flat_hash_map<
                   std::string, absl::Span<xla::ifrt::Device* const>>>),
              GetMpmdAddressableDevices, (), (const, final));
  MOCK_METHOD(
      (absl::StatusOr<absl::flat_hash_map<std::string, CompiledMemoryStats>>),
      GetMpmdCompiledMemoryStats, (), (const, final));
  MOCK_METHOD((absl::StatusOr<
                  absl::flat_hash_map<std::string, xla::ifrt::AttributeMap>>),
              GetMpmdCostAnalysis, (), (const, final));
  MOCK_METHOD((absl::StatusOr<absl::flat_hash_map<
                   std::string, std::vector<std::shared_ptr<HloModule>>>>),
              GetMpmdHloModules, (), (const, final));

  MOCK_METHOD(Client*, client, (), (const, final));
  MOCK_METHOD(absl::string_view, name, (), (const, final));
  MOCK_METHOD(absl::StatusOr<std::optional<std::string>>, Fingerprint, (),
              (const, final));
  MOCK_METHOD(absl::StatusOr<std::unique_ptr<xla::ifrt::ExecutableVersion>>,
              executable_version, (), (const, final));
  MOCK_METHOD(absl::StatusOr<std::string>, Serialize, (), (const, final));
  MOCK_METHOD(absl::StatusOr<std::string>, GetHumanReadableProgramText, (),
              (const, final));
  MOCK_METHOD(UserContextRef, user_context, (), (const, final));
  MOCK_METHOD(tsl::Future<>, GetReadyFuture, (), (const, override));
  MOCK_METHOD(int, num_devices, (), (const, final));
  MOCK_METHOD(int64_t, SizeOfGeneratedCodeInBytes, (), (const, final));
  MOCK_METHOD(absl::StatusOr<CompiledMemoryStats>, GetCompiledMemoryStats, (),
              (const, final));
  MOCK_METHOD(std::optional<std::vector<xla::OpSharding>>,
              GetParameterShardings, (), (const, final));
  MOCK_METHOD(std::optional<std::vector<xla::OpSharding>>, GetOutputShardings,
              (), (const, final));
  MOCK_METHOD(
      absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>,
      GetParameterLayouts, (), (const, final));
  MOCK_METHOD(absl::StatusOr<absl::Span<const int>>, GetDonatableInputIndices,
              (), (const, final));
  MOCK_METHOD(
      absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>,
      GetOutputLayouts, (), (const, final));
  MOCK_METHOD(absl::StatusOr<std::vector<std::vector<absl::string_view>>>,
              GetOutputMemoryKinds, (), (const, final));
  MOCK_METHOD(absl::StatusOr<std::vector<std::shared_ptr<xla::HloModule>>>,
              GetHloModules, (), (const, final));
  MOCK_METHOD(absl::StatusOr<xla::ifrt::AttributeMap>, GetCostAnalysis, (),
              (const, final));
  MOCK_METHOD(absl::StatusOr<ExecuteResult>, Execute,
              (absl::Span<ArrayRef> args, const ExecuteOptions& options,
               std::optional<DeviceListRef> devices),
              (final));
  MOCK_METHOD(absl::Span<Device* const>, addressable_devices, (),
              (const, final));
  MOCK_METHOD(std::optional<DeviceListRef>, devices, (), (const, final));

  // Type-identity anchor used by `llvm::RTTIExtends`; declared here and
  // defined elsewhere.
  static char ID;  // NOLINT
};

// host_callback.h

// gMock-based test double for `HostCallback`. The class itself is `final`,
// and its only mocked member is `Serialize`.
class MockHostCallback final
    : public llvm::RTTIExtends<MockHostCallback, HostCallback> {
 public:
  MOCK_METHOD(std::string, Serialize, (), (const, final));

  // Type-identity anchor used by `llvm::RTTIExtends`; declared here and
  // defined elsewhere.
  static char ID;  // NOLINT
};

// gMock-based test double for `LoadedHostCallback`. The class itself is
// `final`. Unlike `MockHostCallback::Serialize`, `Serialize` here returns
// `absl::StatusOr<std::string>`.
class MockLoadedHostCallback final
    : public llvm::RTTIExtends<MockLoadedHostCallback, LoadedHostCallback> {
 public:
  MOCK_METHOD(Client*, client, (), (const, final));
  MOCK_METHOD(absl::StatusOr<std::string>, Serialize, (), (const, final));

  // Type-identity anchor used by `llvm::RTTIExtends`; declared here and
  // defined elsewhere.
  static char ID;  // NOLINT
};

// sharding.h

// gMock-based test double for `Sharding`.
//
// The qualifier list passed to MOCK_METHOD must be comma-separated: gMock
// recognizes `(const final)` as `const` only and silently drops `final`.
// Every mocked member below therefore spells it `(const, final)`.
class MockSharding : public llvm::RTTIExtends<MockSharding, Sharding> {
 public:
  // Creates a mock sharding over a device list built from an empty set of
  // devices, a default-constructed `MemoryKind`, and
  // `is_fully_replicated == false`.
  MockSharding()
      : llvm::RTTIExtends<MockSharding, Sharding>(
            BasicDeviceList::Create({}), MemoryKind(),
            /*is_fully_replicated=*/false) {}

  // Creates a mock sharding that forwards the given properties to the base
  // class constructor. `devices` is taken by value and moved onward to avoid
  // an extra reference-count round trip.
  MockSharding(DeviceListRef devices, MemoryKind memory_kind,
               bool is_fully_replicated)
      : llvm::RTTIExtends<MockSharding, Sharding>(
            std::move(devices), memory_kind, is_fully_replicated) {}

  // Return types containing top-level commas are wrapped in an extra pair of
  // parentheses so that MOCK_METHOD treats them as a single macro argument.
  MOCK_METHOD((absl::StatusOr<std::vector<std::pair<Shape, ShardingRef>>>),
              Disassemble,
              (const Shape& shape,
               SingleDeviceShardSemantics single_device_shard_semantics),
              (const, final));
  MOCK_METHOD(
      (absl::StatusOr<std::vector<std::pair<DynamicShape, ShardingRef>>>),
      Disassemble,
      (const DynamicShape& dynamic_shape,
       SingleDeviceShardSemantics single_device_shard_semantics),
      (const, final));
  MOCK_METHOD(absl::StatusOr<std::vector<IndexDomain>>, IndexDomains,
              (const Shape& shape,
               SingleDeviceShardSemantics single_device_shard_semantics),
              (const, final));
  MOCK_METHOD(absl::StatusOr<Shape>, GetShardShape, (const Shape& shape),
              (const, final));
  MOCK_METHOD(bool, HasSamePartitioning, (const Sharding& other),
              (const, final));
  MOCK_METHOD(absl::StatusOr<std::unique_ptr<Sharding>>, WithDeviceAssignment,
              (std::optional<DeviceListRef> devices,
               std::optional<MemoryKind> memory_kind),
              (const, final));
  MOCK_METHOD(void, Hash, (absl::HashState), (const, final));

  // Not mocked: always returns the fixed string "MockSharding".
  std::string DebugString() const final { return "MockSharding"; }

  // Type-identity anchor used by `llvm::RTTIExtends`; declared here and
  // defined elsewhere.
  static char ID;  // NOLINT
};

}  // namespace ifrt
}  // namespace xla

#endif  // XLA_PYTHON_IFRT_MOCK_H_
