# Copyright 2026 Marimo. All rights reserved.
from __future__ import annotations

import os
import random
import tempfile
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING
from unittest.mock import Mock, patch

import msgspec
import pytest

from marimo._utils.platform import is_windows
from tests._server.conftest import get_session_manager
from tests._server.mocks import (
    token_header,
    with_read_session,
    with_session,
    with_websocket_session,
)
from tests.mocks import EDGE_CASE_FILENAMES
from tests.utils import try_assert_n_times

if TYPE_CHECKING:
    from starlette.testclient import TestClient, WebSocketTestSession

# Session id passed to every session decorator in this module. It is also
# sent as the ``Marimo-Session-Id`` header, presumably so the server routes
# each request to that session.
SESSION_ID = "session-123"
# Headers attached to every authenticated request in this module.
# ``token_header`` is assumed to build the auth header for the given (fake)
# access token — confirm in tests._server.mocks.
HEADERS = {
    "Marimo-Session-Id": SESSION_ID,
    **token_header("fake-token"),
}


@with_session(SESSION_ID)
def test_rename(client: TestClient) -> None:
    """Renaming the session's notebook makes it appear at the new path."""
    file_router = get_session_manager(client).file_router
    original_key = file_router.get_unique_file_key()

    assert original_key
    original = Path(original_key)
    assert original.exists()

    suffix = random.randint(0, 100000)
    target = original.parent / f"test_{suffix}.py"

    result = client.post(
        "/api/kernel/rename",
        headers=HEADERS,
        json={"filename": str(target)},
    )
    assert result.json() == {"success": True}

    def _target_exists() -> None:
        # The rename may land asynchronously, so this check is retried.
        assert target.exists()

    try_assert_n_times(5, _target_exists)


@pytest.mark.flaky(reruns=5)
@with_session(SESSION_ID)
def test_read_code(client: TestClient) -> None:
    """read_code returns the notebook source in edit mode."""
    resp = client.post("/api/kernel/read_code", headers=HEADERS, json={})
    assert resp.status_code == 200, resp.text
    payload = resp.json()
    assert "import marimo" in payload["contents"]


@with_read_session(SESSION_ID, include_code=True)
def test_read_code_in_run_mode_with_include_code(client: TestClient) -> None:
    """Test read_code works in run mode when include_code=True."""
    resp = client.post("/api/kernel/read_code", headers=HEADERS, json={})
    assert resp.status_code == 200, resp.text
    payload = resp.json()
    assert "import marimo" in payload["contents"]


@with_read_session(SESSION_ID, include_code=False)
def test_read_code_in_run_mode_without_include_code(
    client: TestClient,
) -> None:
    """Test read_code is not accessible in run mode when include_code=False."""
    resp = client.post("/api/kernel/read_code", headers=HEADERS, json={})
    # Access must be refused: 401 (unauthorized) or 403 (forbidden).
    denied_codes = (401, 403)
    assert resp.status_code in denied_codes


@pytest.mark.flaky(reruns=5)
@with_session(SESSION_ID)
def test_save_file(client: TestClient) -> None:
    """Saving writes cell code, names, and configs to the notebook file.

    After the first save is verified on disk, the notebook is saved again
    with a default-named (``__``), visible cell. That second request must
    also succeed.
    """
    filename = get_session_manager(client).file_router.get_unique_file_key()
    assert filename
    path = Path(filename)

    response = client.post(
        "/api/kernel/save",
        headers=HEADERS,
        json={
            "cellIds": ["1"],
            "filename": str(path),
            "codes": ["import marimo as mo"],
            "names": ["my_cell"],
            "configs": [
                {
                    "hide_code": True,
                    "disabled": False,
                }
            ],
        },
    )
    assert response.status_code == 200, response.text
    # The response body does not change between retries, so check it once
    # here rather than inside the retried file-content closure below.
    assert "import marimo" in response.text

    def _assert_contents() -> None:
        # The write may land asynchronously, so this check is retried.
        file_contents = path.read_text(encoding="utf-8")
        assert "import marimo as mo" in file_contents
        assert "@app.cell(hide_code=True)" in file_contents
        assert "my_cell" in file_contents

    try_assert_n_times(5, _assert_contents)

    # Save back with a default-named, visible cell; this must succeed too.
    save_back = client.post(
        "/api/kernel/save",
        headers=HEADERS,
        json={
            "cellIds": ["1"],
            "filename": str(path),
            "codes": ["import marimo as mo"],
            "names": ["__"],
            "configs": [
                {
                    "hide_code": False,
                }
            ],
        },
    )
    assert save_back.status_code == 200, save_back.text


@pytest.mark.xfail(
    reason="Flaky in CI, can't repro locally",
)
@with_session(SESSION_ID)
def test_save_with_header(client: TestClient) -> None:
    """Saving preserves a user-written header (docstring plus comments).

    Every read and write uses UTF-8 explicitly. This keeps the header
    round-trip independent of the platform's default locale encoding.
    """
    filename = get_session_manager(client).file_router.get_unique_file_key()
    assert filename
    path = Path(filename)
    assert path.exists()

    copyright_year = datetime.now().year
    header = (
        '"""This is a docstring"""\n\n'
        + f"# Copyright {copyright_year}\n# Linter ignore\n"
    )
    # Prepend a header to the file
    contents = path.read_text(encoding="utf-8")
    path.write_text(header + contents, encoding="utf-8")

    response = client.post(
        "/api/kernel/save",
        headers=HEADERS,
        json={
            "cellIds": ["1"],
            "filename": str(path),
            "codes": ["import marimo as mo"],
            "names": ["my_cell"],
            "configs": [
                {
                    "hide_code": True,
                    "disabled": False,
                }
            ],
        },
    )

    assert response.status_code == 200, response.text
    assert "import marimo" in response.text

    def _assert_contents() -> None:
        file_contents = path.read_text(encoding="utf-8")
        assert "import marimo as mo" in file_contents
        # Race condition with uv (seen in python 3.10): a ``# /// ... # ///``
        # block (presumably PEP 723 script metadata) may be prepended; strip
        # it before checking the header.
        if file_contents.startswith("# ///"):
            file_contents = file_contents.split("# ///")[2].lstrip()
        assert file_contents.startswith(header.rstrip()), "Header was removed"
        assert "@app.cell(hide_code=True)" in file_contents
        assert "my_cell" in file_contents

    try_assert_n_times(5, _assert_contents)


@pytest.mark.flaky(reruns=5)
@with_session(SESSION_ID)
def test_save_with_invalid_file(client: TestClient) -> None:
    """Saving drops a header that contains executable code.

    The prepended header includes a top-level ``print`` call. After the
    save, the file must start with ``import marimo``, meaning the header was
    removed. Files are read and written as UTF-8 explicitly, so the result
    does not depend on the locale's default encoding.
    """
    filename = get_session_manager(client).file_router.get_unique_file_key()
    assert filename
    path = Path(filename)
    assert path.exists()

    header = (
        '"""This is a docstring"""\n\n'
        + 'print("dont do this!")\n'
        + "# Linter ignore\n"
    )

    # Prepend a header to the file
    contents = path.read_text(encoding="utf-8")
    path.write_text(header + contents, encoding="utf-8")

    response = client.post(
        "/api/kernel/save",
        headers=HEADERS,
        json={
            "cellIds": ["1"],
            "filename": str(path),
            "codes": ["import marimo as mo"],
            "names": ["my_cell"],
            "configs": [
                {
                    "hide_code": True,
                    "disabled": False,
                }
            ],
        },
    )

    assert response.status_code == 200, response.text
    assert "import marimo" in response.text

    def _assert_contents() -> None:
        file_contents = path.read_text(encoding="utf-8")
        assert "@app.cell(hide_code=True)" in file_contents
        assert "my_cell" in file_contents

        # Race condition with uv (seen in python 3.10): strip a prepended
        # ``# /// ... # ///`` block before checking the file start.
        if file_contents.startswith("# ///"):
            file_contents = file_contents.split("# ///")[2].lstrip()
        assert file_contents.startswith("import marimo"), (
            "Header was not removed"
        )

    try_assert_n_times(5, _assert_contents)


@with_session(SESSION_ID)
def test_save_file_cannot_rename(client: TestClient) -> None:
    """Saving under a filename other than the session's is rejected (400)."""
    payload = {
        "cellIds": ["1"],
        "filename": "random_filename.py",
        "codes": ["import marimo as mo"],
        "names": ["my_cell"],
        "configs": [{"hide_code": True, "disabled": False}],
    }
    result = client.post("/api/kernel/save", headers=HEADERS, json=payload)
    assert result.status_code == 400

    detail = result.json()["detail"]
    assert detail
    assert "cannot rename" in detail


@pytest.mark.flaky(reruns=5)
@with_session(SESSION_ID)
def test_save_app_config(client: TestClient) -> None:
    """Saving the app config rewrites the width in ``marimo.App(...)``."""
    medium_marker = 'marimo.App(width="medium"'

    key = get_session_manager(client).file_router.get_unique_file_key()
    assert key
    notebook = Path(key)

    def _medium_absent() -> None:
        # Wait until the notebook does not (yet) declare a medium width.
        assert medium_marker not in notebook.read_text()

    try_assert_n_times(5, _medium_absent)

    resp = client.post(
        "/api/kernel/save_app_config",
        headers=HEADERS,
        json={"config": {"width": "medium"}},
    )
    assert resp.status_code == 200, resp.text
    assert "import marimo" in resp.text

    def _medium_present() -> None:
        assert medium_marker in notebook.read_text()

    try_assert_n_times(5, _medium_present)


@with_session(SESSION_ID)
def test_copy_file(client: TestClient) -> None:
    """Copying the notebook produces a file with the same content.

    The copy goes in the source's directory, named with a leading
    underscore.
    """
    filename = get_session_manager(client).file_router.get_unique_file_key()
    assert filename
    path = Path(filename)
    assert path.exists()
    file_contents = path.read_text(encoding="utf-8")
    assert "import marimo as mo" in file_contents
    assert 'marimo.App(width="full"' in file_contents

    filename_copy = f"_{os.path.basename(filename)}"
    copied_file = os.path.join(os.path.dirname(filename), filename_copy)
    response = client.post(
        "/api/kernel/copy",
        headers=HEADERS,
        json={
            "source": filename,
            "destination": copied_file,
        },
    )
    assert response.status_code == 200, response.text
    # The response is expected to mention the new file's name.
    assert filename_copy in response.text

    def _assert_contents() -> None:
        # Path.read_text closes its handle, unlike a bare open().read(),
        # which would leak a file descriptor (ResourceWarning).
        copied_contents = Path(copied_file).read_text(encoding="utf-8")
        assert "import marimo as mo" in copied_contents
        assert 'marimo.App(width="full"' in copied_contents

    try_assert_n_times(5, _assert_contents)


@with_websocket_session(SESSION_ID)
def test_rename_propagates(
    client: TestClient, websocket: WebSocketTestSession
) -> None:
    """After a rename, ``__file__``-derived variables report the new path."""

    def _collect_variables() -> dict[str, str]:
        # Block on the websocket until values for two variables have arrived,
        # ignoring every message that is not a "variable-values" update.
        collected: dict[str, str] = {}
        while len(collected) < 2:
            message = websocket.receive_json()
            if message["op"] != "variable-values":
                continue
            for entry in message["data"]["variables"]:
                collected[entry["name"]] = entry["value"]
        return collected

    current_filename = get_session_manager(
        client
    ).file_router.get_unique_file_key()

    assert current_filename
    assert os.path.exists(current_filename)

    run_response = client.post(
        "/api/kernel/run",
        headers=HEADERS,
        json={
            "cellIds": ["cell-1", "cell-2"],
            "codes": ["b = __file__", "a = 'x' + __file__"],
        },
    )
    assert run_response.json() == {"success": True}
    assert run_response.status_code == 200, run_response.text

    before = _collect_variables()
    # Variable outputs are truncated to 50 characters;
    # current_filename can exceed this count on Windows and macOS.
    assert ("x" + current_filename).startswith(before["a"])
    assert current_filename.startswith(before["b"])

    new_filename = os.path.join(
        os.path.dirname(current_filename),
        f"test_{random.randint(0, 100000)}.py",
    )

    rename_response = client.post(
        "/api/kernel/rename",
        headers=HEADERS,
        json={"filename": new_filename},
    )
    assert rename_response.json() == {"success": True}
    assert rename_response.status_code == 200, rename_response.text

    after = _collect_variables()
    assert ("x" + new_filename).startswith(after["a"])
    assert new_filename.startswith(after["b"])


# Edge case tests
@with_session(SESSION_ID)
def test_read_code_without_saved_file(client: TestClient) -> None:
    """read_code rejects a session whose notebook has no file path yet."""
    from marimo._utils.http import HTTPStatus

    # A stand-in session whose app file has never been saved.
    unsaved_session = Mock()
    unsaved_session.app_file_manager.path = None

    with patch("marimo._server.api.endpoints.files.AppState") as state_cls:
        # Any AppState(...) built in the endpoints module now yields the
        # unsaved session from require_current_session().
        state_cls.return_value.require_current_session.return_value = (
            unsaved_session
        )

        resp = client.post(
            "/api/kernel/read_code",
            headers=HEADERS,
            json={},
        )

        assert resp.status_code == HTTPStatus.BAD_REQUEST
        detail = resp.json()["detail"]
        assert "File must be saved before downloading" in detail


@with_session(SESSION_ID)
def test_save_with_unicode_content(client: TestClient) -> None:
    """Non-ASCII cell code (CJK text, emoji, accents) survives a save."""
    key = get_session_manager(client).file_router.get_unique_file_key()
    assert key
    notebook = Path(key)

    unicode_code = """# Unicode test: 你好世界 🌍 ñáéíóú
def test():
    return "Hello 世界" """

    payload = {
        "cellIds": ["1"],
        "filename": str(notebook),
        "codes": [unicode_code],
        "names": ["unicode_cell"],
        "configs": [{}],
    }
    result = client.post("/api/kernel/save", headers=HEADERS, json=payload)
    assert result.status_code == 200

    def _unicode_persisted() -> None:
        saved = notebook.read_text(encoding="utf-8")
        assert "你好世界" in saved
        assert "🌍" in saved

    try_assert_n_times(5, _unicode_persisted)


@with_session(SESSION_ID)
def test_save_with_missing_required_fields(client: TestClient) -> None:
    """A save request without ``codes`` raises msgspec.ValidationError."""
    # "codes" is deliberately omitted from an otherwise valid body.
    incomplete_body = {
        "cellIds": ["1"],
        "filename": "test.py",
        "names": ["my_cell"],
        "configs": [{}],
    }
    with pytest.raises(msgspec.ValidationError):
        client.post("/api/kernel/save", headers=HEADERS, json=incomplete_body)


def test_endpoints_without_authentication(client: TestClient) -> None:
    """File endpoints reject requests that carry no auth/session headers.

    401/403 are the expected auth failures. 422 is also accepted, presumably
    because request validation can run before the auth check.
    """
    endpoints_and_methods = [
        ("/api/kernel/read_code", "post"),
        ("/api/kernel/rename", "post"),
        ("/api/kernel/save", "post"),
        ("/api/kernel/copy", "post"),
        ("/api/kernel/save_app_config", "post"),
    ]

    for endpoint, method in endpoints_and_methods:
        response = getattr(client, method)(
            endpoint,
            json={"test": "data"},
        )
        # Should require authentication. Name the route in the failure
        # message so a regression in one endpoint is identifiable.
        assert response.status_code in [401, 403, 422], (
            f"{method.upper()} {endpoint} returned {response.status_code}"
        )


@pytest.mark.skipif(
    is_windows(), reason="Windows doesn't support these edge case filenames"
)
@with_session(SESSION_ID)
def test_rename_with_edge_case_filenames(client: TestClient) -> None:
    """Renaming to unicode / space-containing filenames keeps the notebook."""
    with tempfile.TemporaryDirectory() as tmpdir:
        target_dir = Path(tmpdir)
        for name in EDGE_CASE_FILENAMES:
            file_router = get_session_manager(client).file_router
            assert file_router.get_unique_file_key()

            destination = target_dir / name
            reply = client.post(
                "/api/kernel/rename",
                headers=HEADERS,
                json={"filename": str(destination)},
            )
            assert reply.json() == {"success": True}

            # Bind this iteration's destination as a default argument so the
            # closure does not depend on late binding of the loop variable.
            def _renamed_file_ok(path: Path = destination) -> None:
                assert path.exists()
                # The moved file must stay readable and keep its code.
                text = path.read_text(encoding="utf-8")
                assert "import marimo" in text

            try_assert_n_times(5, _renamed_file_ok)
