/**
 * Copyright 2017-2026, XGBoost contributors
 */
#include "test_quantile_obj.h"

#include <xgboost/base.h>        // for Args
#include <xgboost/context.h>     // for Context
#include <xgboost/data.h>        // for MetaInfo
#include <xgboost/objective.h>   // for ObjFunction
#include <xgboost/span.h>        // for Span
#include <xgboost/tree_model.h>  // for RegTree

#include <memory>  // for unique_ptr
#include <vector>  // for vector

#include "../helpers.h"                   // CheckConfigReload,MakeCUDACtx,DeclareUnifiedTest

namespace xgboost {
/**
 * @brief Exercise the `reg:quantileerror` objective on the given context.
 *
 * A multi-quantile configuration is only checked for a config save/load round
 * trip. A single-quantile configuration (alpha = 0.6) is then checked for the
 * round trip and for its per-row gradient and hessian.
 *
 * @param ctx Context used to create the objective.
 */
void TestQuantile(Context const* ctx) {
  // Build the objective, apply `cfg`, and verify the configuration reloads.
  auto make_objective = [ctx](Args const& cfg) {
    std::unique_ptr<ObjFunction> objective{ObjFunction::Create("reg:quantileerror", ctx)};
    objective->Configure(cfg);
    CheckConfigReload(objective, "reg:quantileerror");
    return objective;
  };

  // Multiple quantiles: the returned temporary is discarded right away, so only
  // the reload check inside `make_objective` is exercised.
  make_objective(Args{{"quantile_alpha", "[0.6, 0.8]"}});

  // Single quantile, alpha = 0.6.
  auto const objective = make_objective(Args{{"quantile_alpha", "0.6"}});

  std::vector<float> predictions{1.0f, 2.0f, 3.0f};
  std::vector<float> targets{3.0f, 2.0f, 1.0f};
  std::vector<float> sample_weights{1.0f, 1.0f, 1.0f};
  // Expected gradient: -0.6 where the prediction is below the label, 0.4 where
  // it is equal to or above it. The expected hessian equals the weights.
  std::vector<float> expected_grad{-0.6f, 0.4f, 0.4f};
  std::vector<float> expected_hess(sample_weights);
  CheckObjFunction(objective, predictions, targets, sample_weights, expected_grad, expected_hess);
}

/**
 * @brief Check the intercept (base score) estimated by `reg:quantileerror`.
 *
 * Two quantiles (alpha = 0.6 and 0.8) are configured over a single label
 * column holding 0, 1, ..., 9. One base score per quantile is checked, first
 * without sample weights and then with weights that decrease as the label
 * grows.
 *
 * @param ctx Context used to create the objective; the label storage is also
 *            placed on `ctx->Device()`.
 */
void TestQuantileIntercept(Context const* ctx) {
  Args args{{"quantile_alpha", "[0.6, 0.8]"}};
  std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:quantileerror", ctx)};
  obj->Configure(args);
  // One intercept per configured quantile. Typed as size_t so the size check
  // below does not mix signed and unsigned operands.
  constexpr std::size_t kNumQuantiles = 2;

  MetaInfo info;
  info.num_row_ = 10;
  // Labels: a num_row_ x 1 matrix holding 0, 1, ..., num_row_ - 1.
  info.labels.ModifyInplace([&](HostDeviceVector<float>* data, common::Span<std::size_t> shape) {
    data->SetDevice(ctx->Device());
    data->Resize(info.num_row_);
    shape[0] = info.num_row_;
    shape[1] = 1;

    auto& h_labels = data->HostVector();
    for (std::size_t i = 0; i < info.num_row_; ++i) {
      h_labels[i] = static_cast<float>(i);
    }
  });

  linalg::Vector<float> base_scores;
  obj->InitEstimation(info, &base_scores);
  ASSERT_EQ(base_scores.Size(), kNumQuantiles);
  ASSERT_NEAR(base_scores(0), 5.6, kRtEps);
  ASSERT_NEAR(base_scores(1), 7.8, kRtEps);

  // Weights num_row_ - 1, ..., 1, 0 put more mass on the small labels, so the
  // expected intercepts are lower than in the unweighted case above.
  auto& h_weights = info.weights_.HostVector();
  h_weights.reserve(h_weights.size() + info.num_row_);
  for (std::size_t i = 0; i < info.num_row_; ++i) {
    // i < num_row_, so the unsigned subtraction cannot wrap.
    h_weights.emplace_back(static_cast<float>(info.num_row_ - i - 1));
  }

  obj->InitEstimation(info, &base_scores);
  ASSERT_EQ(base_scores.Size(), kNumQuantiles);
  ASSERT_NEAR(base_scores(0), 3.0, kRtEps);
  ASSERT_NEAR(base_scores(1), 5.0, kRtEps);
}
}  // namespace xgboost
