/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 */

package org.opensearch.transport.grpc.ssl;

import org.opensearch.common.network.NetworkService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ExecutorBuilder;
import org.opensearch.threadpool.FixedExecutorBuilder;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.grpc.interceptor.GrpcInterceptorChain;
import org.junit.After;
import org.junit.Before;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;

import io.grpc.BindableService;
import io.grpc.ServerInterceptor;
import io.grpc.StatusRuntimeException;
import io.grpc.health.v1.HealthCheckResponse;

import static org.opensearch.transport.grpc.ssl.SecureSettingsHelpers.ConnectExceptions.BAD_CERT;
import static org.opensearch.transport.grpc.ssl.SecureSettingsHelpers.getServerClientAuthNone;
import static org.opensearch.transport.grpc.ssl.SecureSettingsHelpers.getServerClientAuthOptional;
import static org.opensearch.transport.grpc.ssl.SecureSettingsHelpers.getServerClientAuthRequired;

/**
 * Tests for {@link SecureNetty4GrpcServerTransport}: plain start/stop plus health checks under the
 * three TLS client-authentication modes provided by {@link SecureSettingsHelpers} (none, optional,
 * required).
 */
public class SecureNetty4GrpcServerTransportTests extends OpenSearchTestCase {
    private NetworkService networkService;
    private ThreadPool threadPool;
    private final List<BindableService> services = new ArrayList<>();

    private ServerInterceptor serverInterceptor;

    /**
     * Builds settings that bind the gRPC transport to the port range reserved for this test JVM.
     *
     * @return settings containing only the gRPC port range
     */
    static Settings createSettings() {
        return Settings.builder().put(SecureNetty4GrpcServerTransport.SETTING_GRPC_PORT.getKey(), getPortRange()).build();
    }

    /**
     * Creates a fresh network service, a thread pool with a fixed "grpc" executor, and an
     * interceptor chain with no interceptors for every test.
     */
    @Before
    public void setup() {
        networkService = new NetworkService(Collections.emptyList());

        // Create a ThreadPool with the gRPC executor
        Settings settings = Settings.builder().put("node.name", "test-node").put("grpc.netty.executor_count", 4).build();
        ExecutorBuilder<?> grpcExecutorBuilder = new FixedExecutorBuilder(settings, "grpc", 4, 1000, "thread_pool.grpc");
        threadPool = new ThreadPool(settings, grpcExecutorBuilder);
        serverInterceptor = new GrpcInterceptorChain(threadPool.getThreadContext(), Collections.emptyList());
    }

    /**
     * Shuts the thread pool down and waits (bounded) for its threads to exit, so that no executor
     * threads leak into subsequent tests.
     */
    @After
    public void shutdown() {
        if (threadPool != null) {
            ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS);
            threadPool = null;
        }
        networkService = null;
    }

    public void testGrpcSecureTransportStartStop() throws Exception {
        try (
            SecureNetty4GrpcServerTransport transport = new SecureNetty4GrpcServerTransport(
                createSettings(),
                services,
                networkService,
                threadPool,
                getServerClientAuthNone(),
                serverInterceptor
            )
        ) {
            startAndAssertBound(transport);
            transport.stop();
        }
    }

    public void testGrpcInsecureAuthTLS() throws Exception {
        try (
            SecureNetty4GrpcServerTransport transport = new SecureNetty4GrpcServerTransport(
                createSettings(),
                services,
                networkService,
                threadPool,
                getServerClientAuthNone(),
                serverInterceptor
            )
        ) {
            startAndAssertBound(transport);
            final TransportAddress remoteAddress = randomFrom(transport.getBoundAddress().boundAddresses());

            // Client without cert
            assertServingAndClose(new NettyGrpcClient.Builder().setAddress(remoteAddress).insecure(true).build());

            transport.stop();
        }
    }

    public void testGrpcOptionalAuthTLS() throws Exception {
        try (
            SecureNetty4GrpcServerTransport transport = new SecureNetty4GrpcServerTransport(
                createSettings(),
                services,
                networkService,
                threadPool,
                getServerClientAuthOptional(),
                serverInterceptor
            )
        ) {
            startAndAssertBound(transport);
            final TransportAddress remoteAddress = randomFrom(transport.getBoundAddress().boundAddresses());

            // Client without cert
            assertServingAndClose(new NettyGrpcClient.Builder().setAddress(remoteAddress).insecure(true).build());

            // Client with trusted cert
            assertServingAndClose(new NettyGrpcClient.Builder().setAddress(remoteAddress).clientAuth(true).build());

            transport.stop();
        }
    }

    public void testGrpcRequiredAuthTLS() throws Exception {
        try (
            SecureNetty4GrpcServerTransport transport = new SecureNetty4GrpcServerTransport(
                createSettings(),
                services,
                networkService,
                threadPool,
                getServerClientAuthRequired(),
                serverInterceptor
            )
        ) {
            startAndAssertBound(transport);
            final TransportAddress remoteAddress = randomFrom(transport.getBoundAddress().boundAddresses());

            // Client without cert: the call must fail, and the failure must be classified as BAD_CERT.
            // The exception returned by assertThrows is inspected directly, so the check cannot pass vacuously.
            NettyGrpcClient hasNoCertClient = new NettyGrpcClient.Builder().setAddress(remoteAddress).insecure(true).build();
            try {
                StatusRuntimeException e = assertThrows(StatusRuntimeException.class, hasNoCertClient::checkHealth);
                assertEquals(BAD_CERT, SecureSettingsHelpers.ConnectExceptions.get(e));
            } finally {
                hasNoCertClient.close();
            }

            // Client with trusted cert
            assertServingAndClose(new NettyGrpcClient.Builder().setAddress(remoteAddress).clientAuth(true).build());

            transport.stop();
        }
    }

    /**
     * Starts {@code transport} and asserts that it bound at least one address and exposes a
     * publish address.
     *
     * @param transport the transport under test
     */
    private static void startAndAssertBound(SecureNetty4GrpcServerTransport transport) throws Exception {
        transport.start();
        assertTrue("transport should bind at least one address", transport.getBoundAddress().boundAddresses().length > 0);
        assertNotNull(transport.getBoundAddress().publishAddress().address());
    }

    /**
     * Asserts that {@code client} reports {@code SERVING}. The client is closed afterwards even if the
     * health check or the assertion fails, so that a failure does not leak the underlying channel.
     *
     * @param client the client to check and then close
     */
    private static void assertServingAndClose(NettyGrpcClient client) throws Exception {
        try {
            assertEquals(HealthCheckResponse.ServingStatus.SERVING, client.checkHealth());
        } finally {
            client.close();
        }
    }
}
