/*
 * Copyright 2023 The OpenXLA Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "xla/python/ifrt_proxy/server/version.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "xla/python/ifrt/serdes_version.h"

namespace xla {
namespace ifrt {
namespace proxy {
namespace {

// Version ranges advertised by the client and by the server. Each range is
// inclusive, i.e. [min_version, max_version]; e.g. {1, 1, 1, 1} is a
// compatible pair where both sides support only version 1.
//
// Members are zero-initialized so that a default-initialized `Param` never
// carries indeterminate values.
struct Param {
  int client_min_version = 0;
  int client_max_version = 0;
  int server_min_version = 0;
  int server_max_version = 0;
};

// Fixture for client/server version ranges that are expected to negotiate
// successfully.
struct CompatibleVersionTest : ::testing::TestWithParam<Param> {};

// ChooseProtocolVersion must return an OK result for every compatible
// parameter set.
TEST_P(CompatibleVersionTest, VerifyProtocolVersion) {
  const auto& [client_min, client_max, server_min, server_max] = GetParam();
  const auto chosen_version =
      ChooseProtocolVersion(client_min, client_max, server_min, server_max);
  EXPECT_THAT(chosen_version, absl_testing::IsOk());
}

// ChooseIfrtSerdesVersionNumber must likewise return an OK result for every
// compatible parameter set, with each bound wrapped in SerDesVersionNumber.
TEST_P(CompatibleVersionTest, VerifyIfrtSerdesVersionNumber) {
  const auto& [client_min, client_max, server_min, server_max] = GetParam();
  const auto chosen_number = ChooseIfrtSerdesVersionNumber(
      SerDesVersionNumber(client_min), SerDesVersionNumber(client_max),
      SerDesVersionNumber(server_min), SerDesVersionNumber(server_max));
  EXPECT_THAT(chosen_number, absl_testing::IsOk());
}

// Each parameter set is {client_min, client_max, server_min, server_max} with
// inclusive ranges that share at least one version.
INSTANTIATE_TEST_SUITE_P(
    CompatibleVersionTest, CompatibleVersionTest,
    ::testing::Values(
        Param{1, 1, 1, 1},  // Identical single-version ranges.
        Param{1, 2, 2, 2},  // Server range lies inside the client range.
        Param{2, 2, 1, 2},  // Client range lies inside the server range.
        Param{2, 3, 1, 2},  // Client extends past the server's max; overlap 2.
        Param{1, 3, 3, 4}   // Partial overlap at version 3.
        ));

// Fixture for client/server version ranges that share no version, so
// negotiation is expected to fail.
struct IncompatibleVersionTest : ::testing::TestWithParam<Param> {};

// ChooseProtocolVersion must reject every incompatible parameter set with
// kInvalidArgument.
TEST_P(IncompatibleVersionTest, VerifyProtocolVersion) {
  const auto& [client_min, client_max, server_min, server_max] = GetParam();
  const auto chosen_version =
      ChooseProtocolVersion(client_min, client_max, server_min, server_max);
  EXPECT_THAT(chosen_version,
              absl_testing::StatusIs(absl::StatusCode::kInvalidArgument));
}

// ChooseIfrtSerdesVersionNumber must likewise reject every incompatible
// parameter set with kInvalidArgument.
TEST_P(IncompatibleVersionTest, VerifyIfrtSerdesVersionNumber) {
  const auto& [client_min, client_max, server_min, server_max] = GetParam();
  const auto chosen_number = ChooseIfrtSerdesVersionNumber(
      SerDesVersionNumber(client_min), SerDesVersionNumber(client_max),
      SerDesVersionNumber(server_min), SerDesVersionNumber(server_max));
  EXPECT_THAT(chosen_number,
              absl_testing::StatusIs(absl::StatusCode::kInvalidArgument));
}

// Each parameter set is {client_min, client_max, server_min, server_max} with
// inclusive ranges that are disjoint. Both directions are covered: a client
// that is too old for the server and a client that is too new for it.
INSTANTIATE_TEST_SUITE_P(
    IncompatibleVersionTest, IncompatibleVersionTest,
    ::testing::Values(
        Param{1, 2, 3, 3},  // Client entirely older than server.
        Param{1, 3, 4, 6},  // Client entirely older than server.
        Param{1, 1, 2, 2},  // Adjacent single-version ranges, client older.
        Param{3, 4, 1, 2},  // Client entirely newer than server.
        Param{2, 2, 1, 1}   // Adjacent single-version ranges, client newer.
        ));

}  // namespace
}  // namespace proxy
}  // namespace ifrt
}  // namespace xla
