/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 */

/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*
 * Modifications Copyright OpenSearch Contributors. See
 * GitHub history for details.
 */

package org.opensearch.index.rankeval;

import org.opensearch.action.OriginalIndices;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParseException;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static org.opensearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
import static org.opensearch.test.XContentTestUtils.insertRandomFields;
import static org.hamcrest.CoreMatchers.containsString;

public class ExpectedReciprocalRankTests extends OpenSearchTestCase {

    /** Absolute tolerance used when comparing computed metric scores with hand-calculated expectations. */
    private static final double DELTA = 1e-13;

    /**
     * With a maximum relevance of 5, the probability of relevance for rating {@code r} is expected to be
     * {@code (2^r - 1) / 2^5}, i.e. 0 for rating 0 up to 31/32 for the maximum rating.
     */
    public void testProbabilityOfRelevance() {
        ExpectedReciprocalRank err = new ExpectedReciprocalRank(5);
        assertEquals(0.0, err.probabilityOfRelevance(0), 0.0);
        assertEquals(1d / 32d, err.probabilityOfRelevance(1), 0.0);
        assertEquals(3d / 32d, err.probabilityOfRelevance(2), 0.0);
        assertEquals(7d / 32d, err.probabilityOfRelevance(3), 0.0);
        assertEquals(15d / 32d, err.probabilityOfRelevance(4), 0.0);
        assertEquals(31d / 32d, err.probabilityOfRelevance(5), 0.0);
    }

    /**
     * Checks the metric against a hand-computed example with maximum relevance 3, so that
     * {@code probR(r) = (2^r - 1) / 8}. Assuming the result ranking is
     *
     * <pre>{@code
     * rank | relevance | probR / r | p        | p * probR / r
     * ---------------------------------------------------------
     * 1    | 3         | 0.875     | 1        | 0.875
     * 2    | 2         | 0.1875    | 0.125    | 0.0234375
     * 3    | 0         | 0         | 0.078125 | 0
     * 4    | 1         | 0.03125   | 0.078125 | 0.00244140625
     * }</pre>
     *
     * where {@code p} is the product of {@code (1 - probR)} over all preceding ranks.
     * err = sum of the last column over the ranks inside the window {@code k}.
     */
    public void testERRAt() {
        List<RatedDocument> rated = new ArrayList<>();
        Integer[] relevanceRatings = new Integer[] { 3, 2, 0, 1 };
        SearchHit[] hits = createSearchHits(rated, relevanceRatings);
        ExpectedReciprocalRank err = new ExpectedReciprocalRank(3, 0, 3);
        assertEquals(0.8984375, err.evaluate("id", hits, rated).metricScore(), DELTA);
        // widen the window so the 4th rank also contributes
        err = new ExpectedReciprocalRank(3, 0, 4);
        assertEquals(0.8984375 + 0.00244140625, err.evaluate("id", hits, rated).metricScore(), DELTA);
    }

    /**
     * Checks the metric when one hit has no rating and no {@code unknown_doc_rating} is configured.
     * Assuming the result ranking is
     *
     * <pre>{@code
     * rank | relevance | probR / r | p        | p * probR / r
     * ---------------------------------------------------------
     * 1    | 3         | 0.875     | 1        | 0.875
     * 2    | n/a       | n/a       | 0.125    | n/a
     * 3    | 0         | 0         | 0.125    | 0
     * 4    | 1         | 0.03125   | 0.125    | 0.00390625
     * }</pre>
     *
     * The unrated hit still occupies rank 2 (later ranks keep their positions) but leaves {@code p} unchanged.
     * err = sum of the last column.
     */
    public void testERRMissingRatings() {
        List<RatedDocument> rated = new ArrayList<>();
        Integer[] relevanceRatings = new Integer[] { 3, null, 0, 1 };
        SearchHit[] hits = createSearchHits(rated, relevanceRatings);
        ExpectedReciprocalRank err = new ExpectedReciprocalRank(3, null, 4);
        EvalQueryQuality evaluation = err.evaluate("id", hits, rated);
        assertEquals(0.875 + 0.00390625, evaluation.metricScore(), DELTA);
        assertEquals(1, ((ExpectedReciprocalRank.Detail) evaluation.getMetricDetails()).getUnratedDocs());
        // supplying 2 as the unknown doc rating must reproduce the fully rated ranking from testERRAt
        err = new ExpectedReciprocalRank(3, 2, 4);
        assertEquals(0.8984375 + 0.00244140625, err.evaluate("id", hits, rated).metricScore(), DELTA);
    }

    /**
     * Builds one search hit per entry of {@code relevanceRatings}. Hit {@code i} gets id {@code String.valueOf(i)}
     * in index "index", shard 0 on node "testnode".
     *
     * @param rated            output list; a {@link RatedDocument} is appended for every non-null rating
     * @param relevanceRatings rating per hit position, or {@code null} to leave that hit unrated
     * @return the hits in ranking order
     */
    private SearchHit[] createSearchHits(List<RatedDocument> rated, Integer[] relevanceRatings) {
        SearchHit[] hits = new SearchHit[relevanceRatings.length];
        for (int i = 0; i < relevanceRatings.length; i++) {
            if (relevanceRatings[i] != null) {
                rated.add(new RatedDocument("index", Integer.toString(i), relevanceRatings[i]));
            }
            hits[i] = new SearchHit(i, Integer.toString(i), Collections.emptyMap(), Collections.emptyMap());
            hits[i].shard(new SearchShardTarget("testnode", new ShardId("index", "uuid", 0), null, OriginalIndices.NONE));
        }
        return hits;
    }

    /**
     * Tests that the metric returns 0.0 when there are no search results.
     */
    public void testNoResults() throws Exception {
        ExpectedReciprocalRank err = new ExpectedReciprocalRank(5, 0, 10);
        assertEquals(0.0, err.evaluate("id", new SearchHit[0], Collections.emptyList()).metricScore(), DELTA);
    }

    /**
     * Parses hand-written JSON. When {@code k} is omitted, 10 is expected; when
     * {@code unknown_doc_rating} is omitted, {@code null} is expected.
     */
    public void testParseFromXContent() throws IOException {
        assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"maximum_relevance\": 5, \"k\" : 15 }", 2, 5, 15);
        assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"maximum_relevance\": 4 }", 2, 4, 10);
        assertParsedCorrect("{ \"maximum_relevance\": 4, \"k\": 23 }", null, 4, 23);
    }

    /** Parses {@code xContent} as JSON and asserts the resulting metric's three settings. */
    private void assertParsedCorrect(String xContent, Integer expectedUnknownDocRating, int expectedMaxRelevance, int expectedK)
        throws IOException {
        try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
            ExpectedReciprocalRank errAt = ExpectedReciprocalRank.fromXContent(parser);
            assertEquals(expectedUnknownDocRating, errAt.getUnknownDocRating());
            assertEquals(expectedK, errAt.getK());
            assertEquals(expectedMaxRelevance, errAt.getMaxRelevance());
        }
    }

    /**
     * Creates a random metric: max relevance and k in [1, 10], unknown doc rating in [0, 10] most of the time,
     * otherwise {@code null}.
     */
    public static ExpectedReciprocalRank createTestItem() {
        Integer unknownDocRating = frequently() ? Integer.valueOf(randomIntBetween(0, 10)) : null;
        int maxRelevance = randomIntBetween(1, 10);
        return new ExpectedReciprocalRank(maxRelevance, unknownDocRating, randomIntBetween(1, 10));
    }

    public void testXContentRoundtrip() throws IOException {
        ExpectedReciprocalRank testItem = createTestItem();
        XContentBuilder builder = MediaTypeRegistry.contentBuilder(randomFrom(XContentType.values()));
        XContentBuilder shuffled = shuffleXContent(testItem.toXContent(builder, ToXContent.EMPTY_PARAMS));
        try (XContentParser itemParser = createParser(shuffled)) {
            // skip the outer START_OBJECT and the metric-name field so fromXContent sees the metric body
            itemParser.nextToken();
            itemParser.nextToken();
            ExpectedReciprocalRank parsedItem = ExpectedReciprocalRank.fromXContent(itemParser);
            assertNotSame(testItem, parsedItem);
            assertEquals(testItem, parsedItem);
            assertEquals(testItem.hashCode(), parsedItem.hashCode());
        }
    }

    /**
     * Serializes a random metric, injects random extra fields, and expects
     * {@link ExpectedReciprocalRank#fromXContent} to reject them. The metric's own parser must be used here: parsing
     * with another metric's parser would fail on the legitimate {@code maximum_relevance} field and make this test
     * pass vacuously. The assertion only checks the parser-name-independent "unknown field" part of the message.
     */
    public void testXContentParsingIsNotLenient() throws IOException {
        ExpectedReciprocalRank testItem = createTestItem();
        XContentType xContentType = randomFrom(XContentType.values());
        BytesReference originalBytes = toShuffledXContent(testItem, xContentType, ToXContent.EMPTY_PARAMS, randomBoolean());
        BytesReference withRandomFields = insertRandomFields(xContentType, originalBytes, null, random());
        try (XContentParser parser = createParser(xContentType.xContent(), withRandomFields)) {
            // skip the outer START_OBJECT and the metric-name field so fromXContent sees the metric body
            parser.nextToken();
            parser.nextToken();
            XContentParseException exception = expectThrows(
                XContentParseException.class,
                () -> ExpectedReciprocalRank.fromXContent(parser)
            );
            assertThat(exception.getMessage(), containsString("unknown field"));
        }
    }

    public void testMetricDetails() {
        int unratedDocs = randomIntBetween(0, 100);
        ExpectedReciprocalRank.Detail detail = new ExpectedReciprocalRank.Detail(unratedDocs);
        assertEquals(unratedDocs, detail.getUnratedDocs());
    }

    /** Round-trips a random metric through the transport wire format and checks that it is equal but not identical. */
    public void testSerialization() throws IOException {
        ExpectedReciprocalRank original = createTestItem();
        ExpectedReciprocalRank deserialized = OpenSearchTestCase.copyWriteable(
            original,
            new NamedWriteableRegistry(Collections.emptyList()),
            ExpectedReciprocalRank::new
        );
        // (expected, actual) order so failure messages read correctly
        assertEquals(original, deserialized);
        assertEquals(original.hashCode(), deserialized.hashCode());
        assertNotSame(original, deserialized);
    }

    public void testEqualsAndHash() throws IOException {
        checkEqualsAndHashCode(
            createTestItem(),
            original -> new ExpectedReciprocalRank(original.getMaxRelevance(), original.getUnknownDocRating(), original.getK()),
            ExpectedReciprocalRankTests::mutateTestItem
        );
    }

    /** Returns a copy of {@code original} that differs in exactly one randomly chosen setting. */
    private static ExpectedReciprocalRank mutateTestItem(ExpectedReciprocalRank original) {
        switch (randomIntBetween(0, 2)) {
            case 0:
                return new ExpectedReciprocalRank(original.getMaxRelevance() + 1, original.getUnknownDocRating(), original.getK());
            case 1:
                return new ExpectedReciprocalRank(
                    original.getMaxRelevance(),
                    randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10)),
                    original.getK()
                );
            case 2:
                return new ExpectedReciprocalRank(
                    original.getMaxRelevance(),
                    original.getUnknownDocRating(),
                    randomValueOtherThan(original.getK(), () -> randomIntBetween(1, 10))
                );
            default:
                throw new IllegalArgumentException("mutation variant not allowed");
        }
    }
}
