/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 */

package org.opensearch.indices.replication;

import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.OpenSearchAllocationTestCase.ShardAllocations;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.routing.IndexRoutingTable;
import org.opensearch.cluster.routing.RoutingNode;
import org.opensearch.cluster.routing.RoutingNodes;
import org.opensearch.cluster.routing.ShardRouting;
import org.opensearch.cluster.routing.allocation.allocator.BalancedShardsAllocator;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.index.IndexModule;
import org.opensearch.indices.replication.common.ReplicationType;
import org.opensearch.test.InternalTestCluster;
import org.opensearch.test.OpenSearchIntegTestCase;
import org.opensearch.test.junit.annotations.TestLogging;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static org.opensearch.cluster.routing.ShardRoutingState.STARTED;
import static org.opensearch.cluster.routing.allocation.allocator.BalancedShardsAllocator.PREFER_PRIMARY_SHARD_BALANCE;
import static org.opensearch.cluster.routing.allocation.allocator.BalancedShardsAllocator.PREFER_PRIMARY_SHARD_REBALANCE;
import static org.opensearch.cluster.routing.allocation.allocator.BalancedShardsAllocator.PRIMARY_SHARD_REBALANCE_BUFFER;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;

@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 0)
public class SegmentReplicationAllocationIT extends SegmentReplicationBaseIT {

    /**
     * Creates an index with the requested shard layout and replication type. The index query cache is disabled for
     * the created index.
     *
     * @param idxName      name of the index to create
     * @param shardCount   number of primary shards
     * @param replicaCount number of replicas per primary shard
     * @param isSegRep     {@code true} to use SEGMENT replication, {@code false} to use DOCUMENT replication
     */
    private void createIndex(String idxName, int shardCount, int replicaCount, boolean isSegRep) {
        final ReplicationType replicationType = isSegRep ? ReplicationType.SEGMENT : ReplicationType.DOCUMENT;
        final Settings.Builder indexSettings = Settings.builder()
            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, shardCount)
            .put(IndexModule.INDEX_QUERY_CACHE_ENABLED_SETTING.getKey(), false)
            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, replicaCount)
            .put(IndexMetadata.SETTING_REPLICATION_TYPE, replicationType);
        prepareCreate(idxName, indexSettings).get();
    }

    /**
     * Sets the {@code PREFER_PRIMARY_SHARD_BALANCE} allocator setting to {@code "true"} as a persistent cluster
     * setting and asserts that the settings update was acknowledged.
     */
    public void enablePreferPrimaryBalance() {
        final Settings.Builder persistentSettings = Settings.builder()
            .put(BalancedShardsAllocator.PREFER_PRIMARY_SHARD_BALANCE.getKey(), "true");
        assertAcked(client().admin().cluster().prepareUpdateSettings().setPersistentSettings(persistentSettings));
    }

    /**
     * Applies the three primary-balance related allocator settings as persistent cluster settings in a single update
     * and asserts that the update was acknowledged.
     *
     * @param preferPrimaryBalance   value for {@code PREFER_PRIMARY_SHARD_BALANCE}
     * @param preferPrimaryRebalance value for {@code PREFER_PRIMARY_SHARD_REBALANCE}
     * @param buffer                 value for {@code PRIMARY_SHARD_REBALANCE_BUFFER}
     */
    public void setAllocationRelocationStrategy(boolean preferPrimaryBalance, boolean preferPrimaryRebalance, float buffer) {
        final Settings.Builder persistentSettings = Settings.builder()
            .put(PREFER_PRIMARY_SHARD_BALANCE.getKey(), preferPrimaryBalance)
            .put(PREFER_PRIMARY_SHARD_REBALANCE.getKey(), preferPrimaryRebalance)
            .put(PRIMARY_SHARD_REBALANCE_BUFFER.getKey(), buffer);
        assertAcked(client().admin().cluster().prepareUpdateSettings().setPersistentSettings(persistentSettings));
    }

    /**
     * Verifies that primary balance is attained during allocation, both per index and across all indices.
     * <p>
     * A random number of nodes and indices is created. Every index has exactly one primary shard and zero or one
     * replica, and indices alternate between segment and document replication. Once all indices are green, the
     * per-index and the global primary distribution are checked (the latter with a buffer of {@code 0.0}).
     */
    public void testGlobalPrimaryAllocation() throws Exception {
        internalCluster().startClusterManagerOnlyNode();
        final int maxReplicaCount = 1;
        final int maxShardCount = 1;
        final int nodeCount = randomIntBetween(maxReplicaCount + 1, 10);
        final int numberOfIndices = randomIntBetween(5, 10);

        logger.info("--> Creating {} nodes", nodeCount);
        for (int i = 0; i < nodeCount; i++) {
            internalCluster().startNode();
        }
        enablePreferPrimaryBalance();

        for (int i = 0; i < numberOfIndices; i++) {
            final String indexName = "test" + i;
            final int shardCount = randomIntBetween(1, maxShardCount);
            final int replicaCount = randomIntBetween(0, maxReplicaCount);
            logger.info("--> Creating index {} with shard count {} and replica count {}", indexName, shardCount, replicaCount);
            // Even-numbered indices use segment replication, odd-numbered ones use document replication.
            createIndex(indexName, shardCount, replicaCount, i % 2 == 0);
            ensureGreen(TimeValue.timeValueSeconds(60));
        }

        final ClusterState state = client().admin().cluster().prepareState().execute().actionGet().getState();
        logger.info(ShardAllocations.printShardDistribution(state));
        verifyPerIndexPrimaryBalance();
        verifyPrimaryBalance(0.0f);
    }

    /**
     * Verifies the happy path where primary shard allocation is balanced per index when multiple indices are created.
     * <p>
     * This test generally passes even without primary shard balance enabled, because of the nature of the allocation
     * algorithm, which assigns all primary shards first and the replica copies afterwards.
     */
    public void testPerIndexPrimaryAllocation() throws Exception {
        internalCluster().startClusterManagerOnlyNode();
        final int maxReplicaCount = 2;
        final int maxShardCount = 5;
        final int nodeCount = randomIntBetween(maxReplicaCount + 1, 10);
        final int numberOfIndices = randomIntBetween(5, 10);

        logger.info("--> Creating {} nodes", nodeCount);
        for (int i = 0; i < nodeCount; i++) {
            internalCluster().startNode();
        }
        enablePreferPrimaryBalance();

        for (int i = 0; i < numberOfIndices; i++) {
            final String indexName = "test" + i;
            final int shardCount = randomIntBetween(1, maxShardCount);
            final int replicaCount = randomIntBetween(0, maxReplicaCount);
            logger.info("--> Creating index {} with shard count {} and replica count {}", indexName, shardCount, replicaCount);
            // Even-numbered indices use segment replication, odd-numbered ones use document replication.
            createIndex(indexName, shardCount, replicaCount, i % 2 == 0);
            ensureGreen(TimeValue.timeValueSeconds(60));
            // Log the shard distribution after every index creation to ease debugging of failures.
            final ClusterState state = client().admin().cluster().prepareState().execute().actionGet().getState();
            logger.info(ShardAllocations.printShardDistribution(state));
        }
        verifyPerIndexPrimaryBalance();
    }

    /**
     * Verifies balanced primary shard allocation for a single index with a large shard count when a node goes down
     * and a new node joins the cluster. This results in shard distribution skewness, and the re-balancing logic is
     * expected to restore a balanced primary shard distribution.
     */
    @TestLogging(reason = "Enable debug logs from cluster and index replication package", value = "org.opensearch.cluster:DEBUG,org.opensearch.indices.replication:DEBUG")
    public void testSingleIndexShardAllocation() throws Exception {
        internalCluster().startClusterManagerOnlyNode();
        final int maxReplicaCount = 1;
        final int maxShardCount = 50;
        final int nodeCount = 5;

        logger.info("--> Creating {} nodes", nodeCount);
        final List<String> startedNodes = new ArrayList<>();
        for (int started = 0; started < nodeCount; started++) {
            final String nodeName = internalCluster().startNode();
            startedNodes.add(nodeName);
        }
        enablePreferPrimaryBalance();

        // Modify other configurations, expecting that the primary balance strategy will not be affected.
        final Settings.Builder bufferSetting = Settings.builder().put(PRIMARY_SHARD_REBALANCE_BUFFER.getKey(), 0.2);
        assertAcked(client().admin().cluster().prepareUpdateSettings().setPersistentSettings(bufferSetting));

        // Initial allocation of the index
        createIndex("test", maxShardCount, maxReplicaCount, true);
        ensureGreen(TimeValue.timeValueSeconds(60));
        ClusterState clusterState = client().admin().cluster().prepareState().execute().actionGet().getState();
        logger.info(ShardAllocations.printShardDistribution(clusterState));
        verifyPerIndexPrimaryBalance();

        // Remove the first node that was started
        final String nodeToStop = startedNodes.get(0);
        internalCluster().stopRandomNode(InternalTestCluster.nameFilter(nodeToStop));
        internalCluster().validateClusterFormed();
        ensureGreen(TimeValue.timeValueSeconds(100));
        clusterState = client().admin().cluster().prepareState().execute().actionGet().getState();
        logger.info(ShardAllocations.printShardDistribution(clusterState));
        verifyPerIndexPrimaryBalance();

        // Add a new data-only node
        internalCluster().startDataOnlyNode();
        internalCluster().validateClusterFormed();
        ensureGreen(TimeValue.timeValueSeconds(100));
        clusterState = client().admin().cluster().prepareState().execute().actionGet().getState();
        logger.info(ShardAllocations.printShardDistribution(clusterState));
        verifyPerIndexPrimaryBalance();
    }

    /**
     * Similar to {@code testSingleIndexShardAllocation}, but creates multiple indices and adds and removes multiple
     * nodes. After index creation, after adding nodes and after stopping nodes, the test asserts that the primary
     * shard distribution is balanced for each index.
     */
    public void testAllocationWithDisruption() throws Exception {
        internalCluster().startClusterManagerOnlyNode();
        final int maxReplicaCount = 2;
        final int maxShardCount = 2;
        // Create higher number of nodes than number of shards to reduce chances of SameShardAllocationDecider kicking-in
        // and preventing primary relocations
        final int nodeCount = randomIntBetween(5, 10);
        final int numberOfIndices = randomIntBetween(1, 10);

        logger.info("--> Creating {} nodes", nodeCount);
        for (int i = 0; i < nodeCount; i++) {
            internalCluster().startNode();
        }
        enablePreferPrimaryBalance();

        for (int i = 0; i < numberOfIndices; i++) {
            final int shardCount = randomIntBetween(1, maxShardCount);
            final int replicaCount = randomIntBetween(1, maxReplicaCount);
            logger.info("--> Creating index test{} with primary {} and replica {}", i, shardCount, replicaCount);
            // Even-numbered indices use segment replication, odd-numbered ones use document replication.
            createIndex("test" + i, shardCount, replicaCount, i % 2 == 0);
            ensureGreen(TimeValue.timeValueSeconds(60));
            // The guard avoids fetching the cluster state when the intermediate distribution would not be logged.
            if (logger.isTraceEnabled()) {
                final ClusterState intermediateState = client().admin().cluster().prepareState().execute().actionGet().getState();
                logger.trace(ShardAllocations.printShardDistribution(intermediateState));
            }
        }
        ClusterState state = client().admin().cluster().prepareState().execute().actionGet().getState();
        logger.info(ShardAllocations.printShardDistribution(state));
        verifyPerIndexPrimaryBalance();

        final int additionalNodeCount = randomIntBetween(1, 5);
        logger.info("--> Adding {} nodes", additionalNodeCount);

        internalCluster().startNodes(additionalNodeCount);
        ensureGreen(TimeValue.timeValueSeconds(60));
        state = client().admin().cluster().prepareState().execute().actionGet().getState();
        logger.info(ShardAllocations.printShardDistribution(state));
        verifyPerIndexPrimaryBalance();

        // Stop as many random data nodes as were added above, one at a time.
        for (int stopped = 0; stopped < additionalNodeCount; stopped++) {
            internalCluster().stopRandomDataNode();
            // give replica a chance to promote as primary before terminating node containing the replica
            ensureGreen(TimeValue.timeValueSeconds(60));
        }
        state = client().admin().cluster().prepareState().execute().actionGet().getState();
        logger.info("--> Cluster state post nodes stop {}", state);
        logger.info(ShardAllocations.printShardDistribution(state));
        verifyPerIndexPrimaryBalance();
    }

    /**
     * Similar to {@code testSingleIndexShardAllocation}, but creates multiple indices and adds and removes multiple
     * nodes, with {@code PREFER_PRIMARY_SHARD_REBALANCE} set to {@code true} and a random rebalance buffer. After
     * index creation, after adding nodes and after stopping nodes, the test asserts that the primary shard
     * distribution is balanced for each index as well as across the nodes.
     */
    public void testAllocationAndRebalanceWithDisruption() throws Exception {
        internalCluster().startClusterManagerOnlyNode();
        final int maxReplicaCount = 2;
        final int maxShardCount = 2;
        final int numberOfIndices = randomIntBetween(1, 3);
        final int maxPossibleShards = numberOfIndices * maxShardCount * (1 + maxReplicaCount);

        // Decide the layout of every index up front so the node count can be derived from the total shard count.
        // Each entry holds [shardCount, replicaCount] for the index at the same position.
        final List<List<Integer>> shardAndReplicaCounts = new ArrayList<>();
        int totalShards = 0;
        for (int i = 0; i < numberOfIndices; i++) {
            final int shardCount = randomIntBetween(1, maxShardCount);
            final int replicaCount = randomIntBetween(1, maxReplicaCount);
            shardAndReplicaCounts.add(Arrays.asList(shardCount, replicaCount));
            totalShards += shardCount * (1 + replicaCount);
        }
        // Create a strictly higher number of nodes than the number of shards to reduce chances of SameShardAllocationDecider kicking-in
        // and preventing primary relocations
        final int nodeCount = randomIntBetween(totalShards, maxPossibleShards) + 1;
        final float buffer = randomIntBetween(1, 4) * 0.10f;
        logger.info("--> Creating {} nodes", nodeCount);
        for (int i = 0; i < nodeCount; i++) {
            internalCluster().startNode();
        }
        setAllocationRelocationStrategy(true, true, buffer);

        for (int i = 0; i < numberOfIndices; i++) {
            final int shardCount = shardAndReplicaCounts.get(i).get(0);
            final int replicaCount = shardAndReplicaCounts.get(i).get(1);
            logger.info("--> Creating index test{} with primary {} and replica {}", i, shardCount, replicaCount);
            // Even-numbered indices use segment replication, odd-numbered ones use document replication.
            createIndex("test" + i, shardCount, replicaCount, i % 2 == 0);
            ensureGreen(TimeValue.timeValueSeconds(60));
            // The guard avoids fetching the cluster state when the intermediate distribution would not be logged.
            if (logger.isTraceEnabled()) {
                final ClusterState intermediateState = client().admin().cluster().prepareState().execute().actionGet().getState();
                logger.trace(ShardAllocations.printShardDistribution(intermediateState));
            }
        }
        ClusterState state = client().admin().cluster().prepareState().execute().actionGet().getState();
        logger.info(ShardAllocations.printShardDistribution(state));
        verifyPerIndexPrimaryBalance();
        verifyPrimaryBalance(buffer);

        final int additionalNodeCount = randomIntBetween(1, 5);
        logger.info("--> Adding {} nodes", additionalNodeCount);

        internalCluster().startNodes(additionalNodeCount);
        ensureGreen(TimeValue.timeValueSeconds(60));
        state = client().admin().cluster().prepareState().execute().actionGet().getState();
        logger.info(ShardAllocations.printShardDistribution(state));
        verifyPerIndexPrimaryBalance();
        verifyPrimaryBalance(buffer);

        // Stop as many random data nodes as were added above, one at a time.
        for (int stopped = 0; stopped < additionalNodeCount; stopped++) {
            internalCluster().stopRandomDataNode();
            // give replica a chance to promote as primary before terminating node containing the replica
            ensureGreen(TimeValue.timeValueSeconds(60));
        }
        state = client().admin().cluster().prepareState().execute().actionGet().getState();
        logger.info("--> Cluster state post nodes stop {}", state);
        logger.info(ShardAllocations.printShardDistribution(state));
        verifyPerIndexPrimaryBalance();
        verifyPrimaryBalance(buffer);
    }

    /**
     * Utility method which ensures that, for every index in the routing table, the number of started primary shards
     * on each node lies within {@code floor(average) - 1} and {@code ceil(average) + 1}, where the average is the
     * index's active primary shard count divided by the number of routing nodes.
     * <p>
     * The check is retried via {@code assertBusy} for up to 60 seconds and re-reads the cluster state on every attempt.
     *
     * @throws Exception if the distribution is still out of bounds when the wait expires
     */
    private void verifyPerIndexPrimaryBalance() throws Exception {
        assertBusy(() -> {
            final ClusterState currentState = client().admin().cluster().prepareState().execute().actionGet().getState();
            final RoutingNodes nodes = currentState.getRoutingNodes();
            final int nodeCount = nodes.size();
            for (final Map.Entry<String, IndexRoutingTable> index : currentState.getRoutingTable().indicesRouting().entrySet()) {
                final String indexName = index.getKey();
                final int totalPrimaryShards = index.getValue().primaryShardsActive();
                // Multiplying by 1f forces floating point division so floor/ceil see the true average.
                final float avgPrimaryShardsPerNode = totalPrimaryShards * 1f / nodeCount;
                final int lowerBoundPrimaryShardsPerNode = (int) Math.floor(avgPrimaryShardsPerNode) - 1;
                final int upperBoundPrimaryShardsPerNode = (int) Math.ceil(avgPrimaryShardsPerNode) + 1;
                for (RoutingNode node : nodes) {
                    final List<ShardRouting> startedPrimaries = node.shardsWithState(indexName, STARTED)
                        .stream()
                        .filter(ShardRouting::primary)
                        .collect(Collectors.toList());
                    final int primaryCount = startedPrimaries.size();
                    // Asserts value is within the variance threshold (-1/+1 of the average value).
                    assertTrue(
                        "--> Primary balance assertion failure for index "
                            + indexName
                            + " on node "
                            + node.node().getName()
                            + " "
                            + lowerBoundPrimaryShardsPerNode
                            + " <= "
                            + primaryCount
                            + " (assigned) <= "
                            + upperBoundPrimaryShardsPerNode,
                        lowerBoundPrimaryShardsPerNode <= primaryCount && primaryCount <= upperBoundPrimaryShardsPerNode
                    );
                }
            }
        }, 60, TimeUnit.SECONDS);
    }

    /**
     * Utility method which ensures that no node holds more started primary shards (across all indices) than
     * {@code ceil(average) * (1 + buffer)}, where the average is the total active primary shard count divided by the
     * number of routing nodes.
     * <p>
     * The check is retried via {@code assertBusy} for up to 60 seconds and re-reads the cluster state on every attempt.
     *
     * @param buffer tolerated fraction above the rounded-up average primary count per node
     * @throws Exception if a node still exceeds the limit when the wait expires
     */
    private void verifyPrimaryBalance(float buffer) throws Exception {
        assertBusy(() -> {
            final ClusterState currentState = client().admin().cluster().prepareState().execute().actionGet().getState();
            final RoutingNodes nodes = currentState.getRoutingNodes();
            int totalPrimaryShards = 0;
            for (final IndexRoutingTable index : currentState.getRoutingTable().indicesRouting().values()) {
                totalPrimaryShards += index.primaryShardsActive();
            }
            // Multiplying by 1f forces floating point division so ceil sees the true average.
            final int avgPrimaryShardsPerNode = (int) Math.ceil(totalPrimaryShards * 1f / nodes.size());
            final float maxPrimaryShardsPerNode = avgPrimaryShardsPerNode * (1 + buffer);
            for (RoutingNode node : nodes) {
                final List<ShardRouting> startedPrimaries = node.shardsWithState(STARTED)
                    .stream()
                    .filter(ShardRouting::primary)
                    .collect(Collectors.toList());
                final int primaryCount = startedPrimaries.size();
                assertTrue(
                    "--> Primary balance assertion failure on node "
                        + node.node().getName()
                        + " "
                        + primaryCount
                        + " (assigned) <= "
                        + maxPrimaryShardsPerNode,
                    primaryCount <= maxPrimaryShardsPerNode
                );
            }
        }, 60, TimeUnit.SECONDS);
    }
}
