// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package siliconflow

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestEmbeddingOK checks the success path of the client's Embedding call
// against a local mock server.
//
// The fixture deliberately lists the embedding items out of order
// (index 0, 2, 1); the assertions expect the returned items to be ordered
// by ascending index, each paired with its own vector.
func TestEmbeddingOK(t *testing.T) {
	repStr := `{
  "object": "list",
  "data": [
    {
      "object": "embedding",
      "embedding": [
        0.0,
        0.1
      ],
      "index": 0
    },
    {
      "object": "embedding",
      "embedding": [
        2.0,
        2.1
      ],
      "index": 2
    },
    {
      "object": "embedding",
      "embedding": [
        1.0,
        1.1
      ],
      "index": 1
    }
  ],
  "usage": {
    "total_tokens": 10,
    "completion_tokens": 123,
    "prompt_tokens": 123
  }
}`
	// Round-trip the fixture through EmbeddingResponse so the mock serves
	// exactly what the response type can represent.
	var res EmbeddingResponse
	if !assert.NoError(t, json.Unmarshal([]byte(repStr), &res)) {
		return
	}
	// Encode once up front so an encoding failure is reported by the test
	// instead of being silently dropped inside the handler.
	payload, err := json.Marshal(res)
	if !assert.NoError(t, err) {
		return
	}

	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		// A write error here surfaces as a client-side failure below.
		_, _ = w.Write(payload)
	}))
	defer ts.Close()

	c, err := NewSiliconflowClient("mock_key")
	if !assert.NoError(t, err) {
		return
	}

	ret, err := c.Embedding(ts.URL, "BAAI/bge-large-zh-v1.5", []string{"sentence"}, "float", 0, 0)
	// Stop here on failure: dereferencing ret after an error could panic.
	if !assert.NoError(t, err) {
		return
	}
	// Guard the indexed assertions below against a short result.
	if !assert.Len(t, ret.Data, 3) {
		return
	}

	want := []struct {
		index     int
		embedding []float32
	}{
		{index: 0, embedding: []float32{0.0, 0.1}},
		{index: 1, embedding: []float32{1.0, 1.1}},
		{index: 2, embedding: []float32{2.0, 2.1}},
	}
	for i, w := range want {
		assert.Equal(t, w.index, ret.Data[i].Index)
		assert.Equal(t, w.embedding, ret.Data[i].Embedding)
	}
}

// TestEmbeddingFailed checks that Embedding returns an error when the server
// answers with HTTP 401 Unauthorized and no body.
func TestEmbeddingFailed(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
	}))
	defer ts.Close()

	// Client construction must succeed so that the error asserted below can
	// only come from the Embedding call itself.
	c, err := NewSiliconflowClient("mock_key")
	if !assert.NoError(t, err) {
		return
	}

	_, err = c.Embedding(ts.URL, "BAAI/bge-large-zh-v1.5", []string{"sentence"}, "float", 0, 0)
	assert.Error(t, err)
}

// TestRerankOK checks the success path of the client's Rerank call against a
// local mock server.
//
// The fixture lists results with indexes 0, 5, 3; the assertions expect the
// returned results ordered by ascending index (0, 3, 5), each keeping its own
// relevance score, and the billing/token metadata passed through unchanged.
func TestRerankOK(t *testing.T) {
	repStr := `{
  "id": "xxx",
  "results": [
    {
      "index": 0,
      "relevance_score": 0.99184376
    },
	 {
      "index": 5,
      "relevance_score": 0.0034564096
    },
	 {
      "index": 3,
      "relevance_score": 0.0011
    }
  ],
  "meta": {
    "billed_units": {
      "input_tokens": 9,
      "output_tokens": 0,
      "search_units": 0,
      "classifications": 0
    },
    "tokens": {
      "input_tokens": 9,
      "output_tokens": 0
    }
  }
}`
	// Round-trip the fixture through RerankResponse so the mock serves
	// exactly what the response type can represent.
	var res RerankResponse
	if !assert.NoError(t, json.Unmarshal([]byte(repStr), &res)) {
		return
	}
	// Encode once up front so an encoding failure is reported by the test
	// instead of being silently dropped inside the handler.
	payload, err := json.Marshal(res)
	if !assert.NoError(t, err) {
		return
	}

	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		// A write error here surfaces as a client-side failure below.
		_, _ = w.Write(payload)
	}))
	defer ts.Close()

	c, err := NewSiliconflowClient("mock_key")
	if !assert.NoError(t, err) {
		return
	}

	ret, err := c.Rerank(ts.URL, "BAAI/bge-large-zh-v1.5", "query", []string{"text1", "text2", "text3"}, map[string]any{}, 0)
	// Stop here on failure: dereferencing ret after an error could panic.
	if !assert.NoError(t, err) {
		return
	}
	// Guard the indexed assertions below against a short result.
	if !assert.Len(t, ret.Results, 3) {
		return
	}

	want := []struct {
		index int
		score float32
	}{
		{index: 0, score: float32(0.99184376)},
		{index: 3, score: float32(0.0011)},
		{index: 5, score: float32(0.0034564096)},
	}
	for i, w := range want {
		assert.Equal(t, w.index, ret.Results[i].Index)
		assert.Equal(t, w.score, ret.Results[i].RelevanceScore)
	}

	assert.Equal(t, 9, ret.Meta.BilledUnits.InputTokens)
	assert.Equal(t, 0, ret.Meta.BilledUnits.OutputTokens)
	assert.Equal(t, 0, ret.Meta.BilledUnits.SearchUnits)
	assert.Equal(t, 0, ret.Meta.BilledUnits.Classifications)
	assert.Equal(t, 9, ret.Meta.Tokens.InputTokens)
	assert.Equal(t, 0, ret.Meta.Tokens.OutputTokens)
}

// TestRerankFailed checks that Rerank returns an error when the server
// answers with HTTP 401 Unauthorized and no body.
func TestRerankFailed(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
	}))
	defer ts.Close()

	// Client construction must succeed so that the error asserted below can
	// only come from the Rerank call itself.
	c, err := NewSiliconflowClient("mock_key")
	if !assert.NoError(t, err) {
		return
	}

	_, err = c.Rerank(ts.URL, "BAAI/bge-large-zh-v1.5", "query", []string{"text1", "text2", "text3"}, map[string]any{}, 0)
	assert.Error(t, err)
}
