"""Integration tests for token-based authentication in Ray."""

import os
import subprocess
import sys
from pathlib import Path
from typing import Optional

import pytest

import ray
import ray.dashboard.consts as dashboard_consts
from ray._common.network_utils import build_address
from ray._private.test_utils import (
    PrometheusTimeseries,
    client_test_enabled,
    fetch_prometheus_timeseries,
    wait_for_condition,
)

try:
    from ray._raylet import AuthenticationTokenLoader

    _RAYLET_AVAILABLE = True
except ImportError:
    _RAYLET_AVAILABLE = False
    AuthenticationTokenLoader = None

from ray._private.authentication_test_utils import (
    authentication_env_guard,
    clear_auth_token_sources,
    reset_auth_token_state,
    set_auth_mode,
    set_auth_token_path,
    set_env_auth_token,
)

# Module-wide skip: when the `ray._raylet` import above failed (e.g. minimal
# installs), AuthenticationTokenLoader is None and every test in this file is
# skipped instead of erroring.
pytestmark = pytest.mark.skipif(
    not _RAYLET_AVAILABLE,
    reason="Authentication tests require ray._raylet (not available in minimal installs)",
)


def _run_ray_start_and_verify_status(
    args: list, env: dict, expect_success: bool = True, timeout: int = 30
) -> subprocess.CompletedProcess:
    """Run ``ray start <args>`` and assert on its exit status.

    Args:
        args: Extra command-line arguments appended after ``ray start``.
        env: Environment for the subprocess. ``RAY_ENABLE_WINDOWS_OR_OSX_CLUSTER``
            defaults to ``"1"``; a value already present in ``env`` wins.
        expect_success: If True, assert a zero exit code. If False, assert a
            non-zero exit code and that stdout/stderr mention "token".
        timeout: Seconds before ``subprocess.run`` raises ``TimeoutExpired``.

    Returns:
        The completed process, with stdout/stderr captured as text.
    """
    result = subprocess.run(
        ["ray", "start", *args],
        env={"RAY_ENABLE_WINDOWS_OR_OSX_CLUSTER": "1", **env},
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    captured = f"stdout: {result.stdout}, stderr: {result.stderr}"

    if expect_success:
        assert (
            result.returncode == 0
        ), f"ray start should have succeeded. {captured}"
    else:
        assert (
            result.returncode != 0
        ), f"ray start should have failed but succeeded. {captured}"
        # The failure should be about the missing/invalid token. Checking for
        # "token" alone also covers messages that say "authentication token".
        error_output = result.stdout + result.stderr
        assert (
            "token" in error_output.lower()
        ), f"Error message should mention token. Got: {error_output}"

    return result


def _cleanup_ray_start(env: Optional[dict] = None):
    """Best-effort teardown of a cluster started via ``ray start``.

    Shuts down any ``ray.init()`` connection in this process, runs
    ``ray stop --force``, then polls ``ray status`` for up to 10 seconds until
    it reports that no cluster is running. Never fails the calling test just
    because the cluster could not be confirmed as stopped.

    Args:
        env: Environment for the ``ray`` CLI subprocesses. ``None`` inherits
            the current process environment.
    """
    # Close the driver connection before killing the processes it talks to.
    if ray.is_initialized():
        ray.shutdown()

    subprocess.run(
        ["ray", "stop", "--force"],
        env=env,
        capture_output=True,
        timeout=60,  # Generous timeout: stopping can be slow/flaky.
        check=False,  # Don't raise on non-zero exit.
    )

    def ray_stopped():
        # Use the same env as `ray stop` so both commands see the same auth
        # configuration. Bound each probe: a hung `ray status` would otherwise
        # block forever, and wait_for_condition's deadline would never trigger.
        try:
            result = subprocess.run(
                ["ray", "status"],
                env=env,
                capture_output=True,
                timeout=5,
                check=False,
            )
        except subprocess.TimeoutExpired:
            return False
        # ray status returns non-zero when no cluster is running.
        return result.returncode != 0

    try:
        wait_for_condition(ray_stopped, timeout=10, retry_interval_ms=500)
    except Exception:
        # Deliberately best effort: don't fail the test if we can't verify
        # that the cluster stopped.
        pass


@pytest.fixture(autouse=True)
def clean_token_sources(cleanup_auth_token_env):
    """Ensure authentication-related state is clean around each test.

    Requests ``cleanup_auth_token_env`` (defined outside this file), so pytest
    sets that fixture up before this one and tears it down after.

    Before the test: remove every auth-token source, including the default
    token file, and reset the cached token state.

    After the test: shut down any driver connection, force-stop any
    ``ray start`` processes, and reset the cached token state again. The final
    reset is in a ``finally`` so a stop failure (e.g. ``TimeoutExpired`` after
    60s) cannot leak token state into the next test.
    """
    clear_auth_token_sources(remove_default=True)
    reset_auth_token_state()

    yield

    try:
        if ray.is_initialized():
            ray.shutdown()

        subprocess.run(
            ["ray", "stop", "--force"],
            capture_output=True,
            timeout=60,
            check=False,
        )
    finally:
        reset_auth_token_state()


@pytest.mark.skipif(
    client_test_enabled(),
    reason="This test is for starting a new local cluster, not compatible with client mode",
)
def test_local_cluster_generates_token():
    """ray.init() on a new local cluster with auth_mode=token creates a token.

    The token must be written to ``~/.ray/auth_token`` as 64 lowercase hex
    characters.
    """
    token_file = Path.home() / ".ray" / "auth_token"
    assert not token_file.exists(), f"Token file already exists at {token_file}"

    # Turn on token auth through the environment and drop cached token state.
    set_auth_mode("token")
    reset_auth_token_state()

    ray.init()

    try:
        token_dir = token_file.parent
        assert token_file.exists(), (
            f"Token file was not created at {token_file}. "
            f"HOME={os.environ.get('HOME')}, "
            f"Files in {token_dir}: "
            f"{list(token_dir.iterdir()) if token_dir.exists() else 'directory does not exist'}"
        )

        token = token_file.read_text().strip()
        assert len(token) == 64
        assert set(token) <= set("0123456789abcdef")

        # The driver connection must still be alive.
        assert ray.is_initialized()
    finally:
        ray.shutdown()


def test_connect_without_token_raises_error(setup_cluster_with_token_auth):
    """Connecting to a token-auth cluster with no token configured must fail.

    Drops the fixture's driver, disables auth mode locally and removes every
    token source, then expects ``ray.init(address=...)`` to raise
    ConnectionError.
    """
    cluster = setup_cluster_with_token_auth["cluster"]

    # Start from a clean, token-less driver state.
    ray.shutdown()
    set_auth_mode("disabled")
    clear_auth_token_sources(remove_default=True)
    reset_auth_token_state()

    assert not AuthenticationTokenLoader.instance().has_token()

    # With no token the cluster rejects the driver, which surfaces as ConnectionError.
    with pytest.raises(ConnectionError):
        ray.init(address=cluster.address)


@pytest.mark.parametrize(
    "token,expected_status",
    [
        (None, 401),  # No token -> Unauthorized
        ("wrong_token", 403),  # Wrong token -> Forbidden
    ],
    ids=["no_token", "wrong_token"],
)
def test_state_api_auth_failure(token, expected_status, setup_cluster_with_token_auth):
    """Test that state API calls fail with missing or incorrect token.

    Args:
        token: Bearer token to send, or None to omit the Authorization header.
        expected_status: HTTP status the dashboard must return.
        setup_cluster_with_token_auth: Fixture providing ``dashboard_url``.
    """
    import requests

    dashboard_url = setup_cluster_with_token_auth["dashboard_url"]

    headers = {} if token is None else {"Authorization": f"Bearer {token}"}

    # requests never times out unless told to. Bound the call so a wedged
    # dashboard fails this test instead of hanging the whole suite.
    response = requests.get(
        f"{dashboard_url}/api/v0/actors", headers=headers, timeout=10
    )

    assert response.status_code == expected_status, (
        f"State API should return {expected_status}, got {response.status_code}: "
        f"{response.text}"
    )


@pytest.mark.parametrize("tokens_match", [True, False])
def test_cluster_token_authentication(tokens_match, setup_cluster_with_token_auth):
    """A driver can use the cluster only when its token matches the cluster's.

    With the matching token, object store and task calls succeed. With a
    different token of the same length, connecting or the first gRPC call must
    raise ConnectionError or RuntimeError.
    """
    cluster = setup_cluster_with_token_auth["cluster"]
    cluster_token = setup_cluster_with_token_auth["token"]

    # Reconfigure the driver token state to simulate a fresh connection.
    ray.shutdown()
    set_auth_mode("token")
    set_env_auth_token(cluster_token if tokens_match else "b" * 64)
    reset_auth_token_state()

    if not tokens_match:
        with pytest.raises((ConnectionError, RuntimeError)):
            ray.init(address=cluster.address)
            try:
                ray.put("test")
            finally:
                ray.shutdown()
        return

    ray.init(address=cluster.address)

    assert ray.get(ray.put("test_data")) == "test_data"

    @ray.remote
    def test_func():
        return "success"

    assert ray.get(test_func.remote()) == "success"

    ray.shutdown()


@pytest.mark.skipif(
    client_test_enabled(),
    reason="Uses subprocess ray start, not compatible with client mode",
)
@pytest.mark.parametrize("is_head", [True, False])
def test_ray_start_without_token_raises_error(is_head, request):
    """``ray start`` must refuse to run with auth_mode=token and no token.

    Covers starting a head node and joining an existing token-auth cluster.
    The cluster fixture is only requested, lazily, for the joining case.
    """
    # Subprocess env: token auth on, no token via env var or token path.
    env = os.environ.copy()
    env["RAY_AUTH_MODE"] = "token"
    for var in ("RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH"):
        env.pop(var, None)

    # The autouse fixture already removed the default token file.
    default_token_path = Path.home() / ".ray" / "auth_token"
    assert not default_token_path.exists()

    if is_head:
        args = ["--head", "--port=0"]
    else:
        # Joining needs a running head node to point at.
        cluster = request.getfixturevalue("setup_cluster_with_token_auth")["cluster"]
        ray.shutdown()
        args = [f"--address={cluster.address}"]

    _run_ray_start_and_verify_status(args, env, expect_success=False)


@pytest.mark.skipif(
    client_test_enabled(),
    reason="Uses subprocess ray start, not compatible with client mode",
)
def test_ray_start_head_with_token_succeeds():
    """``ray start --head`` works with token auth and a valid token.

    A driver configured with the same token must then be able to connect and
    run a task.
    """
    test_token = "a" * 64
    env = {**os.environ, "RAY_AUTH_TOKEN": test_token, "RAY_AUTH_MODE": "token"}

    try:
        _run_ray_start_and_verify_status(
            ["--head", "--port=0"], env, expect_success=True
        )

        # Give this driver the same token before connecting.
        set_env_auth_token(test_token)
        set_auth_mode("token")
        reset_auth_token_state()

        def connected():
            # Retry until the freshly started head accepts the connection.
            try:
                ray.init(address="auto")
            except Exception:
                return False
            return True

        wait_for_condition(connected, timeout=10)
        assert ray.is_initialized()

        @ray.remote
        def test_func():
            return "success"

        assert ray.get(test_func.remote()) == "success"
    finally:
        # Also shuts down the driver connection.
        _cleanup_ray_start(env)


@pytest.mark.skipif(
    client_test_enabled(),
    reason="Uses subprocess ray start, not compatible with client mode",
)
@pytest.mark.parametrize("token_match", ["correct", "incorrect"])
def test_ray_start_address_with_token(token_match, setup_cluster_with_token_auth):
    """``ray start --address=...`` joins only with the cluster's token.

    The correct token must let the worker join, giving at least 2 nodes. An
    incorrect token must make ``ray start`` fail.
    """
    cluster = setup_cluster_with_token_auth["cluster"]
    cluster_token = setup_cluster_with_token_auth["token"]
    expect_success = token_match == "correct"

    # Reset the driver connection to reuse the fixture-backed cluster.
    ray.shutdown()
    set_auth_mode("token")

    worker_env = os.environ.copy()
    worker_env["RAY_AUTH_MODE"] = "token"
    worker_env["RAY_AUTH_TOKEN"] = cluster_token if expect_success else "b" * 64

    _run_ray_start_and_verify_status(
        [f"--address={cluster.address}", "--num-cpus=1"],
        worker_env,
        expect_success=expect_success,
    )

    if not expect_success:
        return

    try:
        # Reconnect as a driver and wait for the new worker node to register.
        set_env_auth_token(cluster_token)
        reset_auth_token_state()
        ray.init(address=cluster.address)

        wait_for_condition(lambda: len(ray.nodes()) >= 2, timeout=10)

        nodes = ray.nodes()
        assert (
            len(nodes) >= 2
        ), f"Expected at least 2 nodes, got {len(nodes)}: {nodes}"
    finally:
        if ray.is_initialized():
            ray.shutdown()
        _cleanup_ray_start(worker_env)


def test_e2e_operations_with_token_auth(setup_cluster_with_token_auth):
    """Test that e2e operations work with token authentication enabled.

    This verifies that with token auth enabled:
    1. Tasks execute successfully
    2. Actors can be created and called
    3. State API works (list_nodes, list_actors, list_tasks)
    4. Job submission works
    """
    cluster_info = setup_cluster_with_token_auth

    # Test 1: Submit a simple task
    @ray.remote
    def simple_task(x):
        return x + 1

    result = ray.get(simple_task.remote(41))
    assert result == 42, f"Task should return 42, got {result}"

    # Test 2: Create and use an actor
    @ray.remote
    class SimpleActor:
        def __init__(self):
            self.value = 0

        def increment(self):
            self.value += 1
            return self.value

    actor = SimpleActor.remote()
    result = ray.get(actor.increment.remote())
    assert result == 1, f"Actor method should return 1, got {result}"

    # Test 3: State API operations (uses HTTP with auth headers)
    from ray.util.state import list_actors, list_nodes, list_tasks

    # List nodes - should include at least the head node
    wait_for_condition(lambda: len(list_nodes()) >= 1)

    # List actors - should include our SimpleActor. Scan every listed actor,
    # not just the first: this test does not control list ordering, and any
    # other actor listed first would make an actors[0] check fail.
    def check_actors():
        return any("SimpleActor" in a.class_name for a in list_actors())

    wait_for_condition(check_actors)

    # List tasks - should include completed tasks
    wait_for_condition(lambda: len(list_tasks()) >= 1)

    # Test 4: Submit a job and wait for completion
    from ray.job_submission import JobSubmissionClient

    # Create job submission client (uses HTTP with auth headers)
    client = JobSubmissionClient(address=cluster_info["dashboard_url"])

    job_id = client.submit_job(
        entrypoint="echo 'Hello from job'",
    )

    # Wait for any terminal state, then assert it is the successful one.
    def job_finished():
        status = client.get_job_status(job_id)
        return status in ["SUCCEEDED", "FAILED", "STOPPED"]

    wait_for_condition(job_finished, timeout=30)

    final_status = client.get_job_status(job_id)
    assert (
        final_status == "SUCCEEDED"
    ), f"Job should succeed, got status: {final_status}"


def test_logs_api_with_token_auth(setup_cluster_with_token_auth):
    """The log state APIs (list_logs / get_log) work with token auth enabled."""
    from ray.util.state import get_log, list_logs

    first_node_id = ray.nodes()[0]["NodeID"]

    listing = list_logs(node_id=first_node_id)
    assert isinstance(
        listing, dict
    ), f"list_logs should return a dict, got {type(listing)}"

    # raylet.out always exists. Pull only the first streamed chunk.
    no_chunk = object()
    first_chunk = next(
        iter(get_log(filename="raylet.out", node_id=first_node_id, tail=10)),
        no_chunk,
    )
    assert first_chunk is not no_chunk, "Should have received at least one log chunk"
    assert isinstance(
        first_chunk, str
    ), f"get_log chunk should be str, got {type(first_chunk)}"


@pytest.mark.skipif(
    client_test_enabled(),
    reason="Uses subprocess ray CLI, not compatible with client mode",
)
@pytest.mark.parametrize("use_generate", [True, False])
def test_get_auth_token_cli(use_generate):
    """``ray get-auth-token`` prints a 64-char hex token to stdout.

    With ``--generate`` and no token configured, it must create one and log
    that to stderr. With a token in the environment, it must echo that token.
    """
    test_token = "a" * 64

    with authentication_env_guard():
        if use_generate:
            # No token anywhere: the CLI has to create one.
            clear_auth_token_sources(remove_default=True)
            cmd = ["ray", "get-auth-token", "--generate"]
        else:
            # Token supplied through the environment: the CLI must return it.
            set_env_auth_token(test_token)
            reset_auth_token_state()
            cmd = ["ray", "get-auth-token"]

        proc = subprocess.run(
            cmd,
            env=os.environ.copy(),
            capture_output=True,
            text=True,
            timeout=10,
        )

        assert proc.returncode == 0, (
            f"ray get-auth-token should succeed. "
            f"stdout: {proc.stdout}, stderr: {proc.stderr}"
        )

        printed = proc.stdout.strip()
        assert len(printed) == 64, printed
        assert set(printed) <= set("0123456789abcdef"), "Token should be hex"

        if use_generate:
            # Logs go to stderr so stdout stays pipe-friendly.
            assert (
                "generating new authentication token..." in proc.stderr.lower()
            ), "Should log generation to stderr"
        else:
            assert printed == test_token


@pytest.mark.skipif(
    client_test_enabled(),
    reason="Uses subprocess ray CLI, not compatible with client mode",
)
def test_get_auth_token_cli_no_token_no_generate():
    """``ray get-auth-token`` exits non-zero when no token exists and
    ``--generate`` is not passed, printing an error to stderr."""
    with authentication_env_guard():
        reset_auth_token_state()
        clear_auth_token_sources(remove_default=True)

        proc = subprocess.run(
            ["ray", "get-auth-token"],
            env=os.environ.copy(),
            capture_output=True,
            text=True,
            timeout=10,
        )

        stderr_text = proc.stderr.lower()
        assert proc.returncode != 0, "Should fail when no token and no --generate"
        assert "error" in stderr_text, "Should print error to stderr"
        assert "no" in stderr_text and "token" in stderr_text


@pytest.mark.skipif(
    client_test_enabled(),
    reason="Uses subprocess ray CLI, not compatible with client mode",
)
def test_get_auth_token_cli_piping():
    """Test that ``ray get-auth-token`` output can be piped.

    Pipes the CLI into ``wc -c`` and expects exactly 64 characters, i.e. the
    token with no trailing newline.
    """
    test_token = "b" * 64

    with authentication_env_guard():
        set_env_auth_token(test_token)
        reset_auth_token_state()
        env = os.environ.copy()

        # shell=True is required for the pipeline; the command is a constant.
        result = subprocess.run(
            "ray get-auth-token | wc -c",
            shell=True,
            env=env,
            capture_output=True,
            text=True,
            timeout=10,
        )

        # A pipeline's exit status is that of its last command (wc), so a
        # failing `ray get-auth-token` would only show up as a wrong count.
        # Include stderr, where the CLI reports errors, in failure messages.
        assert result.returncode == 0, f"Pipeline failed. stderr: {result.stderr}"
        char_count = int(result.stdout.strip())
        assert char_count == 64, (
            f"Expected 64 chars (no newline), got {char_count}. "
            f"stderr: {result.stderr}"
        )


@pytest.mark.skipif(
    client_test_enabled(),
    reason="Tests AuthenticationTokenLoader directly, no benefit testing this in client mode",
)
def test_missing_token_file_raises_authentication_error():
    """has_token() raises AuthenticationError when RAY_AUTH_TOKEN_PATH points
    at a file that does not exist.

    The error message must name both the offending path and the variable.
    """
    missing_path = "/nonexistent/path/to/token"

    with authentication_env_guard():
        # Start from no token sources, then point the path at nothing.
        clear_auth_token_sources(remove_default=True)
        set_auth_mode("token")
        set_auth_token_path(None, missing_path)
        reset_auth_token_state()

        loader = AuthenticationTokenLoader.instance()

        with pytest.raises(ray.exceptions.AuthenticationError) as exc_info:
            loader.has_token()

        message = str(exc_info.value)
        assert str(Path(missing_path)) in message
        assert "RAY_AUTH_TOKEN_PATH" in message


@pytest.mark.skipif(
    client_test_enabled(),
    reason="Tests AuthenticationTokenLoader directly, no benefit testing this in client mode",
)
def test_empty_token_file_raises_authentication_error(tmp_path):
    """has_token() raises AuthenticationError when RAY_AUTH_TOKEN_PATH points
    at an empty token file.

    The file is configured through set_auth_token_path with "" as the token
    argument.
    """
    empty_file = tmp_path / "empty_token_file.txt"

    with authentication_env_guard():
        # Start from no token sources, then configure the empty-file case.
        clear_auth_token_sources(remove_default=True)
        set_auth_mode("token")
        set_auth_token_path("", empty_file)
        reset_auth_token_state()

        loader = AuthenticationTokenLoader.instance()

        with pytest.raises(ray.exceptions.AuthenticationError) as exc_info:
            loader.has_token()

        message = str(exc_info.value)
        assert "cannot be opened or is empty" in message
        assert str(empty_file) in message


@pytest.mark.skipif(
    client_test_enabled(),
    reason="Tests AuthenticationTokenLoader directly, no benefit testing this in client mode",
)
def test_no_token_with_auth_enabled_returns_false():
    """has_token(ignore_auth_mode=True) reports False, rather than raising,
    when token auth is on but no token exists.

    That lets the caller (ensure_token_if_auth_enabled) choose between
    generating a token and raising an error.
    """
    with authentication_env_guard():
        set_auth_mode("token")
        clear_auth_token_sources(remove_default=True)
        reset_auth_token_state()

        loader = AuthenticationTokenLoader.instance()
        assert loader.has_token(ignore_auth_mode=True) is False


@pytest.mark.skipif(
    client_test_enabled(),
    reason="no benefit testing this in client mode",
)
def test_opentelemetry_metrics_with_token_auth(setup_cluster_with_token_auth):
    """Test that OpenTelemetry metrics are exported with token authentication.

    This test verifies that the C++ OpenTelemetryMetricRecorder correctly includes
    the authentication token in its gRPC metadata when exporting metrics to the
    metrics agent. If the auth headers are missing or incorrect, the metrics agent
    would reject the requests and metrics wouldn't be collected.
    """
    head_node = setup_cluster_with_token_auth["cluster"].head_node
    prom_addresses = [
        build_address(head_node.node_ip_address, head_node.metrics_export_port)
    ]
    timeseries = PrometheusTimeseries()

    # Core node metrics that should always be exported. A substring match
    # against any of them is enough.
    expected_metrics = (
        "ray_node_cpu_utilization",
        "ray_node_mem_used",
        "ray_node_disk_usage",
    )

    def verify_metrics_collected():
        fetch_prometheus_timeseries(prom_addresses, timeseries)
        metric_names = list(timeseries.metric_descriptors.keys())
        return any(
            expected in name
            for expected in expected_metrics
            for name in metric_names
        )

    # If auth were broken, the agent would reject the exports and no metrics
    # would ever show up.
    wait_for_condition(verify_metrics_collected, retry_interval_ms=1000)


def _get_dashboard_agent_address(cluster_info):
    """Return the dashboard agent's HTTP base URL, or None if not published yet.

    Reads the internal-KV entry keyed by
    ``DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX + <node id>`` (dashboard namespace)
    for the first node in ``ray.nodes()``. The value is JSON-decoded as an
    ``[ip, http_port, grpc_port]`` triple.

    Args:
        cluster_info: Unused; kept so existing call sites keep working.

    Returns:
        ``"http://<host>:<http_port>"``, or None when the KV entry is missing
        or empty.
    """
    import json

    node_id = ray.nodes()[0]["NodeID"]
    key = f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id}"
    agent_addr = ray.experimental.internal_kv._internal_kv_get(
        key, namespace=ray._private.ray_constants.KV_NAMESPACE_DASHBOARD
    )
    if not agent_addr:
        return None

    ip, http_port, _grpc_port = json.loads(agent_addr)
    # Use the same host:port helper as the metrics test instead of manual
    # formatting, so IPv6 hosts presumably get bracketed correctly.
    return f"http://{build_address(ip, http_port)}"


def _wait_and_get_dashboard_agent_address(cluster_info, timeout=30):
    """Wait for the dashboard agent address to be published and return it.

    Args:
        cluster_info: Passed through to ``_get_dashboard_agent_address``.
        timeout: Seconds to wait before ``wait_for_condition`` gives up.

    Returns:
        The agent's HTTP base URL, as found by the successful poll.
    """
    address = None

    def agent_address_is_available():
        # Keep the value seen here. Re-querying after the wait could race and
        # return None.
        nonlocal address
        address = _get_dashboard_agent_address(cluster_info)
        return address is not None

    wait_for_condition(agent_address_is_available, timeout=timeout)
    return address


@pytest.mark.parametrize(
    "token_type,expected_status",
    [
        ("none", 401),  # No token -> Unauthorized
        ("valid", "not_auth_error"),  # Valid token -> passes auth (may get 404)
        ("invalid", 403),  # Invalid token -> Forbidden
    ],
    ids=["no_token", "valid_token", "invalid_token"],
)
def test_dashboard_agent_auth(
    token_type, expected_status, setup_cluster_with_token_auth
):
    """Dashboard agent HTTP auth with a missing, valid or invalid token.

    Calls a job-agent endpoint for a job that does not exist. No token must
    give 401 and a wrong token 403. The cluster's token must give anything
    except 401/403; a 404 for the missing job is fine.
    """
    import requests

    cluster_info = setup_cluster_with_token_auth
    agent_address = _wait_and_get_dashboard_agent_address(cluster_info)

    # "none" leaves auth_header unset, so no Authorization header is sent.
    auth_header = None
    if token_type == "valid":
        auth_header = f"Bearer {cluster_info['token']}"
    elif token_type == "invalid":
        auth_header = "Bearer invalid_token_12345678901234567890"
    headers = {} if auth_header is None else {"Authorization": auth_header}

    url = f"{agent_address}/api/job_agent/jobs/nonexistent/logs"
    response = requests.get(url, headers=headers, timeout=5)

    if expected_status != "not_auth_error":
        assert (
            response.status_code == expected_status
        ), f"Expected {expected_status}, got {response.status_code}: {response.text}"
        return

    assert response.status_code not in (401, 403), (
        f"Valid token should be accepted, got {response.status_code}: "
        f"{response.text}"
    )


@pytest.mark.parametrize(
    "endpoint",
    ["/api/healthz", "/api/local_raylet_healthz"],
    ids=["healthz", "local_raylet_healthz"],
)
def test_dashboard_agent_health_check_public(endpoint, setup_cluster_with_token_auth):
    """Agent health-check endpoints must return 200 with no credentials."""
    import requests

    agent_address = _wait_and_get_dashboard_agent_address(
        setup_cluster_with_token_auth
    )
    health_url = f"{agent_address}{endpoint}"

    # Deliberately no Authorization header.
    response = requests.get(health_url, timeout=5)
    assert response.status_code == 200, (
        f"Health check {endpoint} should return 200 without auth, "
        f"got {response.status_code}: {response.text}"
    )


if __name__ == "__main__":
    # Allow running this file directly; propagate pytest's exit status.
    exit_code = pytest.main(["-vv", __file__])
    sys.exit(exit_code)
