"""Tests for the LLM implementations in marimo._ai.llm._impl."""

from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock, patch

import pytest

from marimo._ai._types import ChatMessage, ChatModelConfig, TextPart
from marimo._ai.llm._impl import (
    DEFAULT_SYSTEM_MESSAGE,
    anthropic,
    bedrock,
    google,
    groq,
    openai,
    pydantic_ai,
    simple,
)
from marimo._dependencies.dependencies import DependencyManager

if TYPE_CHECKING:
    from pydantic_ai.settings import ModelSettings


@pytest.fixture
def mock_openai_client():
    """Patch ``openai.OpenAI`` and yield ``(client, client_class)``.

    ``client.chat.completions.create`` returns a one-element list of chunks
    whose ``choices[0].delta.content`` is ``"Test response"``, mimicking a
    streamed chat completion.
    """
    with patch("openai.OpenAI") as openai_cls:
        client = MagicMock()
        openai_cls.return_value = client

        delta = MagicMock()
        delta.content = "Test response"
        choice = MagicMock()
        choice.delta = delta
        chunk = MagicMock()
        chunk.choices = [choice]

        # A streamed response is only iterated, so a plain list suffices.
        client.chat.completions.create.return_value = [chunk]

        yield client, openai_cls


@pytest.fixture
def mock_groq_client():
    """Patch ``groq.Groq`` and yield ``(client, client_class)``.

    ``client.chat.completions.create`` returns a non-streaming response
    whose ``choices[0].message.content`` is ``"Test response"``.
    """
    with patch("groq.Groq") as groq_cls:
        client = MagicMock()
        groq_cls.return_value = client

        message = MagicMock()
        message.content = "Test response"
        choice = MagicMock()
        choice.message = message
        response = MagicMock()
        response.choices = [choice]
        client.chat.completions.create.return_value = response

        yield client, groq_cls


@pytest.fixture
def mock_anthropic_client():
    """Fixture for mocking the Anthropic client.

    Patches ``anthropic.Anthropic`` and yields ``(client, client_class)``.
    Both entry points are configured so the fixture works whichever one the
    implementation calls:

    * ``client.messages.create`` returns a response whose ``content`` is a
      single block with ``text == "Test response"``.
    * ``client.messages.stream`` returns a context manager whose
      ``text_stream`` yields ``"Test response"``.
    """
    with patch("anthropic.Anthropic") as mock_anthropic_class:
        mock_client = MagicMock()
        mock_anthropic_class.return_value = mock_client

        # Non-streaming response structure
        mock_response = MagicMock()
        mock_response.content = [MagicMock(text="Test response")]
        mock_client.messages.create.return_value = mock_response

        # Streaming response structure, used as
        # `with client.messages.stream(...) as stream: stream.text_stream`
        mock_stream = MagicMock()
        mock_stream.__enter__.return_value = mock_stream
        mock_stream.__exit__.return_value = None
        mock_stream.text_stream = ["Test response"]
        mock_client.messages.stream.return_value = mock_stream

        yield mock_client, mock_anthropic_class


@pytest.fixture
def mock_google_client():
    """Fixture for mocking the Google client.

    Patches ``google.genai.Client`` and yields ``(client, client_class)``.
    Both generation entry points are configured:

    * ``client.models.generate_content`` returns a response whose ``text``
      is ``"Test response"``.
    * ``client.models.generate_content_stream`` returns a one-chunk list
      whose chunk ``text`` is ``"Test response"``.
    """
    with patch("google.genai.Client") as mock_google_class:
        mock_client = MagicMock()
        mock_google_class.return_value = mock_client

        # Non-streaming response structure
        mock_response = MagicMock()
        mock_response.text = "Test response"
        mock_client.models.generate_content.return_value = mock_response

        # Streaming response structure: an iterable of chunks
        mock_chunk = MagicMock()
        mock_chunk.text = "Test response"
        mock_client.models.generate_content_stream.return_value = [mock_chunk]

        yield mock_client, mock_google_class


@pytest.fixture
def mock_azure_openai_client():
    """Patch ``openai.AzureOpenAI`` and yield ``(client, client_class)``.

    ``client.chat.completions.create`` returns a one-element list of chunks
    whose ``choices[0].delta.content`` is ``"Test response"``, mimicking a
    streamed chat completion.
    """
    with patch("openai.AzureOpenAI") as azure_cls:
        client = MagicMock()
        azure_cls.return_value = client

        delta = MagicMock()
        delta.content = "Test response"
        choice = MagicMock()
        choice.delta = delta
        chunk = MagicMock()
        chunk.choices = [choice]

        # A streamed response is only iterated, so a plain list suffices.
        client.chat.completions.create.return_value = [chunk]

        yield client, azure_cls


@pytest.fixture
def mock_litellm_completion():
    """Fixture for mocking ``litellm.completion``.

    Patches ``litellm.completion`` and yields the mock. Calling it returns a
    non-streaming response whose ``choices[0].message.content`` is
    ``"Test response"``.
    """
    # Named differently from the fixture itself to avoid shadowing it.
    with patch("litellm.completion") as completion_mock:
        # Setup the response structure
        mock_response = MagicMock()
        mock_choice = MagicMock()
        mock_message = MagicMock()
        mock_message.content = "Test response"
        mock_choice.message = mock_message
        mock_response.choices = [mock_choice]
        completion_mock.return_value = mock_response

        yield completion_mock


@pytest.fixture
def test_messages():
    """A one-message conversation: a single user prompt."""
    prompt = ChatMessage(role="user", content="Test prompt")
    return [prompt]


@pytest.fixture
def test_config():
    """Shared sampling configuration used by the provider call tests."""
    settings = {
        "max_tokens": 100,
        "temperature": 0.7,
        "top_p": 0.9,
        "frequency_penalty": 0.5,
        "presence_penalty": 0.5,
    }
    return ChatModelConfig(**settings)


def test_simple_model() -> None:
    """`simple` wraps a plain callable and applies it to the prompt text."""
    doubled = simple(lambda text: text * 2)

    single = [ChatMessage(role="user", content="hey")]
    assert doubled(single, ChatModelConfig()) == "heyhey"

    history = [
        ChatMessage(role="user", content="hey", attachments=[]),
        ChatMessage(role="user", content="goodbye", attachments=[]),
    ]
    # With several messages, the result is built from the last one.
    assert doubled(history, ChatModelConfig()) == "goodbyegoodbye"


@pytest.fixture(autouse=True)
def mock_environment_variables():
    """Run every test with an environment containing only fake API keys."""
    key_names = (
        "OPENAI_API_KEY",
        "ANTHROPIC_API_KEY",
        "GOOGLE_AI_API_KEY",
    )
    fake_env = dict.fromkeys(key_names, "test-key")
    # clear=True hides any real variables from the developer's shell.
    with patch.dict(os.environ, fake_env, clear=True):
        yield


@pytest.mark.skipif(
    DependencyManager.openai.has(), reason="OpenAI is installed"
)
def test_openai_require() -> None:
    """Without openai installed, invoking the model raises ModuleNotFoundError."""
    chat_model = openai("gpt-4")
    prompt = [ChatMessage(role="user", content="Test prompt")]
    default_config = ChatModelConfig()
    with pytest.raises(ModuleNotFoundError):
        chat_model(prompt, default_config)


@pytest.mark.skipif(
    DependencyManager.anthropic.has(), reason="Anthropic is installed"
)
def test_anthropic_require() -> None:
    """Without anthropic installed, invoking the model must fail to import."""
    chat_model = anthropic("claude-3-opus-20240229")
    prompt = [ChatMessage(role="user", content="Test prompt")]
    default_config = ChatModelConfig()
    with pytest.raises(ModuleNotFoundError):
        chat_model(prompt, default_config)


@pytest.mark.skipif(
    DependencyManager.google_ai.has(), reason="Google AI is installed"
)
def test_google_require() -> None:
    """Without the Google AI dependency, invoking the model must fail."""
    chat_model = google("gemini-2.5-flash-preview-05-20")
    prompt = [ChatMessage(role="user", content="Test prompt")]
    default_config = ChatModelConfig()
    with pytest.raises(ModuleNotFoundError):
        chat_model(prompt, default_config)


@pytest.mark.skipif(DependencyManager.groq.has(), reason="Groq is installed")
def test_groq_require() -> None:
    """Without groq installed, invoking the model must fail to import."""
    chat_model = groq("llama3-70b-8192")
    prompt = [ChatMessage(role="user", content="Test prompt")]
    default_config = ChatModelConfig()
    with pytest.raises(ModuleNotFoundError):
        chat_model(prompt, default_config)


@pytest.mark.skipif(
    not DependencyManager.openai.has(), reason="OpenAI is not installed"
)
class TestOpenAI:
    """Tests for the OpenAI class."""

    @staticmethod
    def _final_result(result_gen: Any) -> Any:
        """Drain a streamed model result and return its last item.

        A generator object is always truthy, so emptiness has to be checked
        on the materialized list rather than on the generator itself.
        Returns ``""`` when nothing was yielded.
        """
        chunks = list(result_gen)
        return chunks[-1] if chunks else ""

    def test_init(self):
        """Test initialization of the openai class."""
        # Test default initialization
        model = openai("gpt-4")
        assert model.model == "gpt-4"
        assert model.system_message == DEFAULT_SYSTEM_MESSAGE
        assert model.api_key is None
        assert model.base_url is None

        # Test custom initialization
        model = openai(
            "gpt-4",
            system_message="Custom system message",
            api_key="test-key",
            base_url="https://example.com",
        )
        assert model.model == "gpt-4"
        assert model.system_message == "Custom system message"
        assert model.api_key == "test-key"
        assert model.base_url == "https://example.com"

    def test_call(self, mock_openai_client, test_messages, test_config):
        """Test calling the openai class with standard OpenAI client."""
        mock_client, mock_openai_class = mock_openai_client

        # Create model with API key to avoid _require_api_key
        model = openai("gpt-4", api_key="test-key")

        # Consume the stream to get the final result
        result = self._final_result(model(test_messages, test_config))

        # Verify result
        assert result == "Test response"

        # Verify client initialization
        mock_openai_class.assert_called_once_with(
            api_key="test-key", base_url=None
        )

        # Verify API call
        mock_client.chat.completions.create.assert_called_once()
        call_args = mock_client.chat.completions.create.call_args[1]
        assert call_args["model"] == "gpt-4"
        assert len(call_args["messages"]) == 2
        assert call_args["messages"][0]["role"] == "system"
        assert call_args["messages"][0]["content"] == [
            {"type": "text", "text": DEFAULT_SYSTEM_MESSAGE}
        ]
        assert call_args["messages"][1]["role"] == "user"
        assert call_args["messages"][1]["content"] == [
            {"type": "text", "text": "Test prompt"}
        ]
        assert call_args["max_completion_tokens"] == 100
        # Use pytest.approx for floating point comparisons
        assert call_args["temperature"] == pytest.approx(0.7)
        assert call_args["top_p"] == pytest.approx(0.9)
        assert call_args["frequency_penalty"] == pytest.approx(0.5)
        assert call_args["presence_penalty"] == pytest.approx(0.5)

    def test_call_with_base_url(
        self, mock_openai_client, test_messages, test_config
    ):
        """Test calling the openai class with a custom base URL."""
        _, mock_openai_class = mock_openai_client

        # Create model with API key and base URL
        model = openai(
            "gpt-4", api_key="test-key", base_url="https://example.com"
        )

        # Consume the stream to get the final result
        result = self._final_result(model(test_messages, test_config))

        # Verify result
        assert result == "Test response"

        # Verify client initialization with base URL
        mock_openai_class.assert_called_once_with(
            api_key="test-key", base_url="https://example.com"
        )

    def test_call_azure(
        self, mock_azure_openai_client, test_messages, test_config
    ):
        """Test calling the openai class with Azure OpenAI."""
        mock_client, mock_azure_openai_class = mock_azure_openai_client

        # Create model with API key and Azure URL
        model = openai(
            "gpt-4",
            api_key="test-key",
            base_url="https://example.openai.azure.com/openai/deployments/gpt-4/chat/completions?api-version=2023-05-15",
        )

        # Consume the stream to get the final result
        result = self._final_result(model(test_messages, test_config))

        # Verify result
        assert result == "Test response"

        # Verify Azure client initialization
        mock_azure_openai_class.assert_called_once_with(
            api_key="test-key",
            api_version="2023-05-15",
            azure_endpoint="https://example.openai.azure.com",
        )

        # Verify API call
        mock_client.chat.completions.create.assert_called_once()
        call_args = mock_client.chat.completions.create.call_args[1]
        assert call_args["model"] == "gpt-4"

    def test_call_with_empty_response(
        self, mock_openai_client, test_messages, test_config
    ):
        """Test calling the openai class with an empty response."""
        mock_client, _ = mock_openai_client

        # Stream a single chunk that carries no choices at all
        mock_chunk = MagicMock()
        mock_chunk.choices = []
        mock_client.chat.completions.create.return_value = [mock_chunk]

        # Create model with API key
        model = openai("gpt-4", api_key="test-key")

        result = self._final_result(model(test_messages, test_config))

        # Verify an empty string results when the stream has no content
        assert result == ""

    @patch.dict(os.environ, {"OPENAI_API_KEY": "env-key"})
    def test_require_api_key_env(self):
        """Test _require_api_key with environment variable."""
        model = openai("gpt-4")
        assert model._require_api_key == "env-key"

    @patch.dict(os.environ, {}, clear=True)
    @patch("marimo._runtime.context.types.get_context")
    def test_require_api_key_config(self, mock_get_context):
        """Test _require_api_key with config."""
        mock_context = MagicMock()
        mock_context.marimo_config = {
            "ai": {"open_ai": {"api_key": "config-key"}}
        }
        mock_get_context.return_value = mock_context

        model = openai("gpt-4")
        assert model._require_api_key == "config-key"

    @patch.dict(os.environ, {}, clear=True)
    @patch("marimo._runtime.context.types.get_context")
    def test_require_api_key_missing(self, mock_get_context):
        """Test _require_api_key with missing key."""
        mock_context = MagicMock()
        mock_context.marimo_config = {"ai": {"open_ai": {"api_key": ""}}}
        mock_get_context.return_value = mock_context

        model = openai("gpt-4")
        with pytest.raises(ValueError, match="openai api key not provided"):
            _ = model._require_api_key

    # `mock_openai_client` patches openai.OpenAI with a streaming mock so the
    # call below never reaches the network.
    @pytest.mark.usefixtures("mock_openai_client")
    @patch(
        "marimo._dependencies.dependencies.DependencyManager.openai.require"
    )
    def test_dependency_check(self, mock_require):
        """Test that the dependency check is performed."""
        # Create model with API key to avoid _require_api_key issues
        model = openai("gpt-4", api_key="test-key")

        # Call the model
        model([ChatMessage(role="user", content="Test")], ChatModelConfig())

        # Verify dependency check was called
        mock_require.assert_called_once_with(
            "chat model requires openai. `pip install openai`"
        )

    def test_convert_to_openai_messages(self, mock_openai_client):
        """Test that messages are properly converted for OpenAI."""
        mock_client, _ = mock_openai_client

        # Create model with API key to avoid _require_api_key issues
        model = openai("gpt-4", api_key="test-key")

        # Test with multiple messages
        messages = [
            ChatMessage(
                role="system",
                content="Custom system",
                parts=[TextPart(type="text", text="Custom system")],
            ),
            ChatMessage(
                role="user",
                content="Hello",
                parts=[TextPart(type="text", text="Hello")],
            ),
            ChatMessage(
                role="assistant",
                content="Hi there",
                parts=[TextPart(type="text", text="Hi there")],
            ),
            ChatMessage(
                role="user",
                content="How are you?",
                parts=[TextPart(type="text", text="How are you?")],
            ),
        ]

        model(messages, ChatModelConfig())

        # Verify the messages were properly converted
        call_args = mock_client.chat.completions.create.call_args[1]
        assert len(call_args["messages"]) == 5  # system + all messages

        assert call_args["messages"] == [
            {
                "role": "system",
                "content": [{"type": "text", "text": DEFAULT_SYSTEM_MESSAGE}],
            },
            {
                "role": "system",
                "content": [{"type": "text", "text": "Custom system"}],
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": "Hello"}],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "Hi there"}],
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": "How are you?"}],
            },
        ]


@pytest.mark.skipif(
    not DependencyManager.groq.has(), reason="Groq is not installed"
)
class TestGroq:
    """Tests for the Groq class."""

    def test_init(self):
        """Test initialization of the groq class."""
        # Test default initialization
        model = groq("llama3-8b-8192")
        assert model.model == "llama3-8b-8192"
        assert model.system_message == DEFAULT_SYSTEM_MESSAGE
        assert model.api_key is None
        assert model.base_url is None

        # Test custom initialization
        model = groq(
            "llama3-8b-8192",
            system_message="Custom system message",
            api_key="test-key",
            base_url="https://example.com",
        )
        assert model.model == "llama3-8b-8192"
        assert model.system_message == "Custom system message"
        assert model.api_key == "test-key"
        assert model.base_url == "https://example.com"

    def test_call(self, mock_groq_client, test_messages, test_config):
        """Test calling the groq class with standard Groq client."""
        mock_client, mock_groq_class = mock_groq_client

        # Create model with API key to avoid _require_api_key
        model = groq("llama3-8b-8192", api_key="test-key")

        result = model(test_messages, test_config)

        # Verify result
        assert result == "Test response"

        # Verify client initialization
        mock_groq_class.assert_called_once_with(
            api_key="test-key", base_url=None
        )

        # Verify API call
        mock_client.chat.completions.create.assert_called_once()
        call_args = mock_client.chat.completions.create.call_args[1]
        assert call_args["model"] == "llama3-8b-8192"
        assert len(call_args["messages"]) == 2
        assert call_args["messages"][0]["role"] == "system"
        assert call_args["messages"][0]["content"] == DEFAULT_SYSTEM_MESSAGE
        assert call_args["messages"][1]["role"] == "user"
        assert call_args["messages"][1]["content"] == "Test prompt"
        assert call_args["max_tokens"] == 100
        # Use pytest.approx for floating point comparisons
        assert call_args["temperature"] == pytest.approx(0.7)
        assert call_args["top_p"] == pytest.approx(0.9)
        assert call_args["frequency_penalty"] == pytest.approx(0.5)
        assert call_args["presence_penalty"] == pytest.approx(0.5)

    def test_call_with_base_url(
        self, mock_groq_client, test_messages, test_config
    ):
        """Test calling the groq class with a custom base URL."""
        _, mock_groq_class = mock_groq_client

        # Create model with API key and base URL
        model = groq(
            "llama3-8b-8192",
            api_key="test-key",
            base_url="https://example.com",
        )

        result = model(test_messages, test_config)

        # Verify result
        assert result == "Test response"

        # Verify client initialization with base URL
        mock_groq_class.assert_called_once_with(
            api_key="test-key", base_url="https://example.com"
        )

    def test_call_with_empty_response(
        self, mock_groq_client, test_messages, test_config
    ):
        """Test calling the groq class with an empty response."""
        mock_client, _ = mock_groq_client

        # Modify the mock to return an empty content
        mock_client.chat.completions.create.return_value.choices[
            0
        ].message.content = None

        # Create model with API key
        model = groq("llama3-8b-8192", api_key="test-key")

        result = model(test_messages, test_config)

        # Verify empty string is returned when content is None
        assert result == ""

    @patch.dict(os.environ, {"GROQ_API_KEY": "env-key"})
    def test_require_api_key_env(self):
        """Test _require_api_key with environment variable."""
        model = groq("llama3-8b-8192")
        assert model._require_api_key == "env-key"

    @patch.dict(os.environ, {}, clear=True)
    @patch("marimo._runtime.context.types.get_context")
    def test_require_api_key_config(self, mock_get_context):
        """Test _require_api_key with config."""
        mock_context = MagicMock()
        mock_context.marimo_config = {
            "ai": {"groq": {"api_key": "config-key"}}
        }
        mock_get_context.return_value = mock_context

        model = groq("llama3-8b-8192")
        assert model._require_api_key == "config-key"

    @patch.dict(os.environ, {}, clear=True)
    @patch("marimo._runtime.context.types.get_context")
    def test_require_api_key_missing(self, mock_get_context):
        """Test _require_api_key with missing key."""
        mock_context = MagicMock()
        mock_context.marimo_config = {"ai": {"groq": {"api_key": ""}}}
        mock_get_context.return_value = mock_context

        model = groq("llama3-8b-8192")
        with pytest.raises(ValueError, match="groq api key not provided"):
            _ = model._require_api_key

    # `mock_groq_client` patches groq.Groq so the call below stays offline.
    @pytest.mark.usefixtures("mock_groq_client")
    @patch("marimo._dependencies.dependencies.DependencyManager.groq.require")
    def test_dependency_check(self, mock_require):
        """Test that the dependency check is performed."""
        # Create model with API key to avoid _require_api_key issues
        model = groq("llama3-8b-8192", api_key="test-key")

        # Call the model
        model([ChatMessage(role="user", content="Test")], ChatModelConfig())

        # Verify dependency check was called
        mock_require.assert_called_once_with(
            "chat model requires groq. `pip install groq`"
        )

    def test_convert_to_groq_messages(self, mock_groq_client):
        """Test that messages are properly converted for Groq."""
        mock_client, _ = mock_groq_client

        # Create model with API key to avoid _require_api_key issues
        model = groq("llama3-8b-8192", api_key="test-key")

        # Test with multiple messages
        messages = [
            ChatMessage(role="system", content="Custom system"),
            ChatMessage(role="user", content="Hello"),
            ChatMessage(role="assistant", content="Hi there"),
            ChatMessage(role="user", content="How are you?"),
        ]

        model(messages, ChatModelConfig())

        # Verify the messages were properly converted
        call_args = mock_client.chat.completions.create.call_args[1]
        assert len(call_args["messages"]) == 5  # system + all messages
        assert call_args["messages"][0]["role"] == "system"
        assert call_args["messages"][0]["content"] == DEFAULT_SYSTEM_MESSAGE
        assert call_args["messages"][1]["role"] == "system"
        assert call_args["messages"][1]["content"] == "Custom system"
        assert call_args["messages"][2]["role"] == "user"
        assert call_args["messages"][2]["content"] == "Hello"
        assert call_args["messages"][3]["role"] == "assistant"
        assert call_args["messages"][3]["content"] == "Hi there"
        assert call_args["messages"][4]["role"] == "user"
        assert call_args["messages"][4]["content"] == "How are you?"


@pytest.mark.skipif(
    not DependencyManager.google_ai.has(), reason="Google AI is not installed"
)
class TestGoogle:
    """Tests for the google class."""

    def test_init(self) -> None:
        """Test initialization of the google class."""
        model = google("gemini-2.5-flash-preview-05-20")
        assert model.model == "gemini-2.5-flash-preview-05-20"
        assert model.system_message == DEFAULT_SYSTEM_MESSAGE
        assert model.api_key is None

        model = google(
            "gemini-2.5-flash-preview-05-20",
            system_message="Custom system message",
            api_key="test-key",
        )
        assert model.model == "gemini-2.5-flash-preview-05-20"
        assert model.system_message == "Custom system message"
        assert model.api_key == "test-key"

    # Replace the `_require_api_key` property with a plain class attribute for
    # the duration of the test; a single class-level patch is enough, so no
    # per-instance patch (or unused MagicMock return_value) is needed.
    @patch.object(google, "_require_api_key", "test-key")
    @patch("google.genai.Client")
    def test_call(self, mock_genai_client_class: MagicMock) -> None:
        """Test calling the google class."""
        mock_client = MagicMock()
        mock_genai_client_class.return_value = mock_client

        # Setup streaming response
        mock_chunk = MagicMock()
        mock_chunk.text = "Test response"
        mock_client.models.generate_content_stream.return_value = [mock_chunk]

        model = google("gemini-2.5-flash-preview-05-20")
        messages = [ChatMessage(role="user", content="Test prompt")]
        config = ChatModelConfig(
            max_tokens=100,
            temperature=0.7,
            top_p=0.9,
            top_k=10,
            frequency_penalty=0.5,
            presence_penalty=0.5,
        )

        # Consume the stream; a generator is always truthy, so check the
        # materialized list for emptiness instead.
        chunks = list(model(messages, config))
        result = chunks[-1] if chunks else ""
        assert result == "Test response"

        mock_genai_client_class.assert_called_once_with(api_key="test-key")

        mock_client.models.generate_content_stream.assert_called_once()
        call_args = mock_client.models.generate_content_stream.call_args[1]
        assert call_args["model"] == "gemini-2.5-flash-preview-05-20"
        config_arg = call_args["config"]
        assert config_arg["system_instruction"] == DEFAULT_SYSTEM_MESSAGE
        assert config_arg["max_output_tokens"] == 100
        # Use pytest.approx for floating point comparisons
        assert config_arg["temperature"] == pytest.approx(0.7)
        assert config_arg["top_p"] == pytest.approx(0.9)
        assert config_arg["top_k"] == 10
        assert config_arg["frequency_penalty"] == pytest.approx(0.5)
        assert config_arg["presence_penalty"] == pytest.approx(0.5)

    @patch.dict(os.environ, {"GOOGLE_AI_API_KEY": "env-key"})
    def test_require_api_key_env(self) -> None:
        """Test _require_api_key with environment variable."""
        model = google("gemini-2.5-flash-preview-05-20")
        assert model._require_api_key == "env-key"

    @patch.dict(os.environ, {}, clear=True)
    @patch("marimo._runtime.context.types.get_context")
    def test_require_api_key_config(self, mock_get_context: MagicMock) -> None:
        """Test _require_api_key with config."""
        mock_context = MagicMock()
        mock_context.marimo_config = {
            "ai": {"google": {"api_key": "config-key"}}
        }
        mock_get_context.return_value = mock_context

        model = google("gemini-2.5-flash-preview-05-20")
        assert model._require_api_key == "config-key"

    @patch.dict(os.environ, {}, clear=True)
    @patch("marimo._runtime.context.types.get_context")
    def test_require_api_key_missing(
        self, mock_get_context: MagicMock
    ) -> None:
        """Test _require_api_key with missing key."""
        mock_context = MagicMock()
        mock_context.marimo_config = {"ai": {"google": {"api_key": ""}}}
        mock_get_context.return_value = mock_context

        model = google("gemini-2.5-flash-preview-05-20")
        with pytest.raises(ValueError):
            _ = model._require_api_key


@pytest.mark.skipif(
    not DependencyManager.anthropic.has(), reason="Anthropic is not installed"
)
class TestAnthropic:
    """Tests for the ``anthropic`` chat model."""

    @staticmethod
    def _install_stream(
        mock_anthropic_class: MagicMock, text_chunks: list[str]
    ) -> MagicMock:
        """Make the patched client's ``messages.stream`` yield *text_chunks*.

        Args:
            mock_anthropic_class: the patched ``anthropic.Anthropic`` class.
            text_chunks: values exposed as the stream's ``text_stream``.

        Returns:
            The mock client, so tests can inspect the ``stream`` call.
        """
        mock_client = MagicMock()
        mock_anthropic_class.return_value = mock_client

        # ``messages.stream(...)`` is used as a context manager; its
        # ``text_stream`` attribute is (presumably) iterated for text.
        mock_stream = MagicMock()
        mock_stream.__enter__ = MagicMock(return_value=mock_stream)
        mock_stream.__exit__ = MagicMock(return_value=None)
        mock_stream.text_stream = text_chunks
        mock_client.messages.stream.return_value = mock_stream
        return mock_client

    def test_init(self) -> None:
        """Test initialization of the anthropic class."""
        model = anthropic("claude-3-opus-20240229")
        assert model.model == "claude-3-opus-20240229"
        assert model.system_message == DEFAULT_SYSTEM_MESSAGE
        assert model.api_key is None
        assert model.base_url is None

        model = anthropic(
            "claude-3-opus-20240229",
            system_message="Custom system message",
            api_key="test-key",
            base_url="https://example.com",
        )
        assert model.model == "claude-3-opus-20240229"
        assert model.system_message == "Custom system message"
        assert model.api_key == "test-key"
        assert model.base_url == "https://example.com"

    # Replace the ``_require_api_key`` property with the plain string
    # "test-key" so every instance reports that key. Patching it with a bare
    # MagicMock instead would hand back the mock object itself (it is read,
    # never called), so configuring its ``return_value`` would have no effect.
    @patch.object(anthropic, "_require_api_key", "test-key")
    @patch("anthropic.Anthropic")
    def test_call(self, mock_anthropic_class: MagicMock) -> None:
        """Test calling the anthropic class."""
        mock_client = self._install_stream(
            mock_anthropic_class, ["Test response"]
        )

        model = anthropic("claude-3-opus-20240229")
        messages = [ChatMessage(role="user", content="Test prompt")]
        config = ChatModelConfig(
            max_tokens=100,
            temperature=0.7,
            top_p=0.9,
            top_k=10,
        )

        # Materialize the stream before guarding: a generator object is
        # always truthy, so it cannot tell us whether anything was yielded.
        chunks = list(model(messages, config))
        result = chunks[-1] if chunks else ""
        assert result == "Test response"

        mock_anthropic_class.assert_called_once_with(
            api_key="test-key", base_url=None
        )
        mock_client.messages.stream.assert_called_once()
        call_args = mock_client.messages.stream.call_args[1]
        assert call_args["model"] == "claude-3-opus-20240229"
        assert call_args["system"] == DEFAULT_SYSTEM_MESSAGE
        assert call_args["max_tokens"] == 100
        assert call_args["temperature"] == 0.7
        assert call_args["top_p"] == 0.9
        assert call_args["top_k"] == 10

    @patch.object(anthropic, "_require_api_key", "test-key")
    @patch("anthropic.Anthropic")
    def test_call_tool_use(self, mock_anthropic_class: MagicMock) -> None:
        """Test a tool-use style response that streams no text.

        NOTE: this does not exercise real tool-use content blocks; it only
        checks that an empty ``text_stream`` produces an empty result.
        """
        self._install_stream(mock_anthropic_class, [])

        model = anthropic("claude-3-opus-20240229")
        messages = [ChatMessage(role="user", content="Test prompt")]
        config = ChatModelConfig()

        chunks = list(model(messages, config))
        result = chunks[-1] if chunks else ""
        assert result == ""

    @patch.object(anthropic, "_require_api_key", "test-key")
    @patch("anthropic.Anthropic")
    def test_call_empty_content(self, mock_anthropic_class: MagicMock) -> None:
        """Test calling the anthropic class with empty content."""
        self._install_stream(mock_anthropic_class, [])

        model = anthropic("claude-3-opus-20240229")
        messages = [ChatMessage(role="user", content="Test prompt")]
        config = ChatModelConfig()

        chunks = list(model(messages, config))
        result = chunks[-1] if chunks else ""
        assert result == ""

    def test_supports_temperature(self) -> None:
        """Test supports_temperature method."""
        model = anthropic("claude-3-opus-20240229")
        assert model.supports_temperature("claude-3-opus-20240229") is True
        assert model.supports_temperature("claude-3-sonnet-20240229") is True
        assert model.supports_temperature("claude-3-haiku-20240307") is True

        # Reasoning models (>4.0) don't support temperature
        assert model.supports_temperature("claude-sonnet-4-5") is False
        assert model.supports_temperature("claude-opus-4-5") is False
        assert model.supports_temperature("claude-4-opus") is False

    @patch.object(anthropic, "_require_api_key", "test-key")
    @patch("anthropic.Anthropic")
    def test_call_without_temperature_for_reasoning_model(
        self, mock_anthropic_class: MagicMock
    ) -> None:
        """Test that temperature is not included for reasoning models."""
        mock_client = self._install_stream(
            mock_anthropic_class, ["Test response"]
        )

        model = anthropic("claude-sonnet-4-5")
        messages = [ChatMessage(role="user", content="Test prompt")]
        config = ChatModelConfig(
            max_tokens=100,
            temperature=0.7,
            top_p=0.9,
            top_k=10,
        )

        list(model(messages, config))  # Consume the generator

        mock_client.messages.stream.assert_called_once()
        call_args = mock_client.messages.stream.call_args[1]
        assert call_args["model"] == "claude-sonnet-4-5"
        assert call_args["system"] == DEFAULT_SYSTEM_MESSAGE
        assert call_args["max_tokens"] == 100
        # Temperature should not be included for reasoning models
        assert "temperature" not in call_args
        assert call_args["top_p"] == 0.9
        assert call_args["top_k"] == 10

    @patch.dict(os.environ, {"ANTHROPIC_API_KEY": "env-key"})
    def test_require_api_key_env(self) -> None:
        """Test _require_api_key with environment variable."""
        model = anthropic("claude-3-opus-20240229")
        assert model._require_api_key == "env-key"

    @patch.dict(os.environ, {}, clear=True)
    @patch("marimo._runtime.context.types.get_context")
    def test_require_api_key_config(self, mock_get_context: MagicMock) -> None:
        """Test _require_api_key with config."""
        mock_context = MagicMock()
        mock_context.marimo_config = {
            "ai": {"anthropic": {"api_key": "config-key"}}
        }
        mock_get_context.return_value = mock_context

        model = anthropic("claude-3-opus-20240229")
        assert model._require_api_key == "config-key"

    @patch.dict(os.environ, {}, clear=True)
    @patch("marimo._runtime.context.types.get_context")
    def test_require_api_key_missing(
        self, mock_get_context: MagicMock
    ) -> None:
        """Test _require_api_key with missing key."""
        mock_context = MagicMock()
        mock_context.marimo_config = {"ai": {"anthropic": {"api_key": ""}}}
        mock_get_context.return_value = mock_context

        model = anthropic("claude-3-opus-20240229")
        with pytest.raises(ValueError):
            _ = model._require_api_key


@pytest.mark.skipif(
    not DependencyManager.boto3.has() or not DependencyManager.litellm.has(),
    reason="boto3 or litellm is not installed",
)
class TestBedrock:
    """Test the Bedrock model class (which routes calls through LiteLLM)."""

    def test_init(self) -> None:
        """Test initialization of the bedrock model class."""
        model = bedrock(
            "anthropic.claude-3-sonnet-20240229",
            system_message="Test system message",
            region_name="us-east-1",
        )

        # bedrock automatically prefixes with bedrock/ for litellm usage
        assert model.model == "bedrock/anthropic.claude-3-sonnet-20240229"
        assert model.system_message == "Test system message"
        assert model.region_name == "us-east-1"
        assert model.profile_name is None
        assert model.aws_access_key_id is None
        assert model.aws_secret_access_key is None

    def test_init_with_credentials(self) -> None:
        """Test initialization with explicit credentials."""
        model = bedrock(
            "anthropic.claude-3-sonnet-20240229",
            aws_access_key_id="test-key",
            aws_secret_access_key="test-secret",
        )

        assert model.aws_access_key_id == "test-key"
        assert model.aws_secret_access_key == "test-secret"

    def test_init_with_profile(self) -> None:
        """Test initialization with AWS profile."""
        model = bedrock(
            "anthropic.claude-3-sonnet-20240229",
            profile_name="test-profile",
        )

        assert model.profile_name == "test-profile"

    @pytest.mark.xfail(
        reason="latest litellm and openai are not compatible",
    )
    def test_call(
        self,
        mock_litellm_completion: MagicMock,
        test_messages: Any,
        test_config: Any,
    ) -> None:
        """Test calling the bedrock class with LiteLLM client."""
        # NOTE(review): these three fixtures are not defined in this part of
        # the file (presumably earlier in the module or in a conftest).
        # Judging by the assertions below, test_messages holds one
        # "Test prompt" user message; confirm against the fixtures.
        model_name = "anthropic.claude-3-sonnet-20240229"

        # No credentials are passed; the completion call is served by the
        # mock_litellm_completion fixture.
        model = bedrock(model_name)

        result = model(test_messages, test_config)

        # Verify result
        # NOTE(review): the other models in this file stream through
        # generators; if bedrock does too, the stream must be consumed
        # before comparing. Confirm before removing the xfail.
        assert result == "Test response"

        # Verify API call
        mock_litellm_completion.assert_called_once()
        call_args = mock_litellm_completion.call_args[1]
        assert call_args["model"] == f"bedrock/{model_name}"
        assert len(call_args["messages"]) == 2
        assert call_args["messages"][0]["role"] == "system"
        assert call_args["messages"][0]["content"] == DEFAULT_SYSTEM_MESSAGE
        assert call_args["messages"][1]["role"] == "user"
        assert call_args["messages"][1]["content"] == "Test prompt"
        assert call_args["max_tokens"] == 100
        # Use pytest.approx for floating point comparisons
        assert call_args["temperature"] == pytest.approx(0.7)
        assert call_args["top_p"] == pytest.approx(0.9)
        assert call_args["frequency_penalty"] == pytest.approx(0.5)
        assert call_args["presence_penalty"] == pytest.approx(0.5)


@pytest.mark.skipif(
    not DependencyManager.pydantic_ai.has(),
    reason="Pydantic AI is not installed",
)
class TestPydanticAI:
    """Tests for the pydantic_ai ChatModel class."""

    @staticmethod
    def _make_model() -> pydantic_ai:
        """Return a ``pydantic_ai`` model wrapping a fresh mock agent."""
        return pydantic_ai(MagicMock())

    def test_init(self) -> None:
        """Test initialization of the pydantic_ai class."""
        mock_agent = MagicMock()
        model = pydantic_ai(mock_agent)
        assert model.agent is mock_agent

    def test_get_model_settings_full_config(self) -> None:
        """Test _get_model_settings with all config options."""
        model = self._make_model()

        config = ChatModelConfig(
            max_tokens=100,
            temperature=0.7,
            top_p=0.9,
            frequency_penalty=0.5,
            presence_penalty=0.3,
        )

        settings: ModelSettings = model._get_model_settings(config)

        assert settings.get("max_tokens") == 100
        assert settings.get("temperature") == pytest.approx(0.7)
        assert settings.get("top_p") == pytest.approx(0.9)
        assert settings.get("frequency_penalty") == pytest.approx(0.5)
        assert settings.get("presence_penalty") == pytest.approx(0.3)

    def test_get_model_settings_partial_config(self) -> None:
        """Test _get_model_settings with partial config."""
        model = self._make_model()

        config = ChatModelConfig(
            max_tokens=200,
            temperature=0.5,
        )

        settings: ModelSettings = model._get_model_settings(config)

        assert settings.get("max_tokens") == 200
        assert settings.get("temperature") == pytest.approx(0.5)
        # Options left unset in the config come back absent (or None).
        assert settings.get("top_p") is None
        assert settings.get("frequency_penalty") is None
        assert settings.get("presence_penalty") is None

    def test_get_model_settings_empty_config(self) -> None:
        """Test _get_model_settings with empty config."""
        model = self._make_model()

        config = ChatModelConfig()

        settings: ModelSettings = model._get_model_settings(config)

        assert settings == {}

    def test_build_ui_messages_with_parts(self) -> None:
        """Test _build_ui_messages with messages that have parts."""
        from pydantic_ai.ui.vercel_ai.request_types import (
            TextUIPart,
            UIMessage,
        )

        model = self._make_model()

        messages = [
            ChatMessage(
                role="user",
                content="Hello",
                id="msg-1",
                parts=[TextPart(type="text", text="Hello")],
            ),
            ChatMessage(
                role="assistant",
                content="Hi there",
                id="msg-2",
                parts=[TextPart(type="text", text="Hi there")],
            ),
        ]

        ui_messages = model._build_ui_messages(messages)

        assert ui_messages == [
            UIMessage(
                id="msg-1",
                role="user",
                parts=[TextUIPart(type="text", text="Hello")],
                metadata=None,
            ),
            UIMessage(
                id="msg-2",
                role="assistant",
                parts=[TextUIPart(type="text", text="Hi there")],
                metadata=None,
            ),
        ]

    def test_build_ui_messages_without_parts(self) -> None:
        """Test _build_ui_messages falls back to content when no parts."""
        from pydantic_ai.ui.vercel_ai.request_types import (
            TextUIPart,
            UIMessage,
        )

        model = self._make_model()

        messages = [
            ChatMessage(
                role="user",
                content="Hello from content",
                id="msg-1",
                parts=None,
            ),
        ]

        ui_messages = model._build_ui_messages(messages)

        assert ui_messages == [
            UIMessage(
                id="msg-1",
                role="user",
                parts=[TextUIPart(type="text", text="Hello from content")],
                metadata=None,
            ),
        ]

    def test_build_ui_messages_with_empty_parts(self) -> None:
        """Test _build_ui_messages uses content when parts is empty."""
        from pydantic_ai.ui.vercel_ai.request_types import (
            TextUIPart,
            UIMessage,
        )

        model = self._make_model()

        messages = [
            ChatMessage(
                role="user",
                content="Hello from content",
                id="msg-1",
                parts=[],
            ),
        ]

        ui_messages = model._build_ui_messages(messages)

        assert ui_messages == [
            UIMessage(
                id="msg-1",
                role="user",
                parts=[TextUIPart(type="text", text="Hello from content")],
                metadata=None,
            ),
        ]

    def test_build_ui_messages_with_unknown_parts(self) -> None:
        """Test _build_ui_messages handles unknown pydantic parts."""
        from pydantic_ai.ui.vercel_ai.request_types import (
            DynamicToolInputAvailablePart,
            UIMessage,
        )

        model = self._make_model()

        dynamic_tool_part = DynamicToolInputAvailablePart(
            tool_name="tool-name",
            type="dynamic-tool",
            tool_call_id="tool-call-id",
            state="input-available",
            input={"input": "input"},
            call_provider_metadata={"provider": {"name": "provider"}},
        )

        messages = [
            ChatMessage(
                role="user",
                content="Hello",
                id="msg-1",
                parts=[dynamic_tool_part],  # type: ignore
            ),
        ]

        ui_messages = model._build_ui_messages(messages)
        # A part that is already a pydantic-ai UI part passes through as-is.
        assert ui_messages == [
            UIMessage(
                id="msg-1",
                role="user",
                parts=[dynamic_tool_part],
                metadata=None,
            ),
        ]

    def test_build_ui_messages_with_unknown_dicts(self) -> None:
        """Test _build_ui_messages converts raw dicts into UI message parts."""
        from pydantic_ai.ui.vercel_ai.request_types import (
            TextUIPart,
            UIMessage,
        )

        model = self._make_model()

        messages = [
            ChatMessage(
                role="user",
                content="Hello",
                id="msg-1",
                parts=[{"type": "text", "text": "Hello"}],  # type: ignore
            ),
        ]

        ui_messages = model._build_ui_messages(messages)
        assert ui_messages == [
            UIMessage(
                id="msg-1",
                role="user",
                parts=[TextUIPart(type="text", text="Hello")],
                metadata=None,
            ),
        ]

    def test_build_ui_messages_generates_id_when_missing(self) -> None:
        """Test _build_ui_messages generates ID when message has no id."""
        model = self._make_model()

        messages = [
            ChatMessage(
                role="user",
                content="Hello",
                id=None,
                parts=[TextPart(type="text", text="Hello")],
            ),
        ]

        ui_messages = model._build_ui_messages(messages)

        assert len(ui_messages) == 1
        # ID is generated, so just check it starts with "message"
        assert ui_messages[0].id.startswith("message")
        assert ui_messages[0].role == "user"

    def test_build_ui_messages_preserves_metadata(self) -> None:
        """Test _build_ui_messages preserves message metadata."""
        from pydantic_ai.ui.vercel_ai.request_types import (
            TextUIPart,
            UIMessage,
        )

        model = self._make_model()

        messages = [
            ChatMessage(
                role="user",
                content="Hello",
                id="msg-1",
                parts=[TextPart(type="text", text="Hello")],
                metadata={"custom_key": "custom_value"},
            ),
        ]

        ui_messages = model._build_ui_messages(messages)

        assert ui_messages == [
            UIMessage(
                id="msg-1",
                role="user",
                parts=[TextUIPart(type="text", text="Hello")],
                metadata={"custom_key": "custom_value"},
            ),
        ]

    async def test_stream_response(self) -> None:
        """Test _stream_response streams Vercel AI events."""
        from pydantic_ai.ui.vercel_ai.response_types import TextDeltaChunk

        model = self._make_model()

        messages = [
            ChatMessage(
                role="user",
                content="Hello",
                id="msg-1",
                parts=[TextPart(type="text", text="Hello")],
            ),
        ]
        config = ChatModelConfig(max_tokens=100)

        # Mock the VercelAIAdapter
        with patch("pydantic_ai.ui.vercel_ai.VercelAIAdapter") as mock_adapter:
            mock_instance = MagicMock()
            mock_adapter.return_value = mock_instance

            # Use actual pydantic-ai chunk types that have model_dump
            async def mock_run_stream(**_kwargs: Any):
                yield TextDeltaChunk(id="1", type="text-delta", delta="Hello")
                yield TextDeltaChunk(id="2", type="text-delta", delta=" World")

            mock_instance.run_stream = mock_run_stream

            chunks = []
            async for chunk in model._stream_response(messages, config):
                chunks.append(chunk)

            # Chunks come out serialized to plain dicts.
            assert chunks == [
                {"id": "1", "type": "text-delta", "delta": "Hello"},
                {"id": "2", "type": "text-delta", "delta": " World"},
            ]

    async def test_stream_text(self) -> None:
        """Test _stream_text streams text from the model."""
        mock_agent = MagicMock()
        model = pydantic_ai(mock_agent)

        messages = [
            ChatMessage(
                role="user",
                content="Hello",
                id="msg-1",
                parts=[TextPart(type="text", text="Hello")],
            ),
        ]
        config = ChatModelConfig(max_tokens=100)

        # Mock VercelAIAdapter.load_messages
        with patch("pydantic_ai.ui.vercel_ai.VercelAIAdapter") as mock_adapter:
            mock_adapter.load_messages.return_value = []

            # Mock the agent's run_stream as an async context manager
            async def mock_stream_text(delta=True):  # noqa: ARG001
                yield "Hello"
                yield " World"

            mock_result = MagicMock()
            mock_result.stream_text = mock_stream_text

            async def mock_aenter(self: Any):  # noqa: ARG001
                return mock_result

            async def mock_aexit(self: Any, *args: Any):
                pass

            mock_context = MagicMock()
            mock_context.__aenter__ = mock_aenter
            mock_context.__aexit__ = mock_aexit

            mock_agent.run_stream.return_value = mock_context

            chunks = []
            async for chunk in model._stream_text(messages, config):
                chunks.append(chunk)

            assert chunks == ["Hello", " World"]

    def test_call_returns_stream_response(self) -> None:
        """Test that __call__ returns _stream_response result."""
        import inspect

        model = self._make_model()

        messages = [
            ChatMessage(
                role="user",
                content="Hello",
                id="msg-1",
                parts=[TextPart(type="text", text="Hello")],
            ),
        ]
        config = ChatModelConfig()

        result = model(messages, config)
        # Should return an async generator (it is not iterated here)
        assert inspect.isasyncgen(result)

    @patch(
        "marimo._dependencies.dependencies.DependencyManager.pydantic_ai.require"
    )
    def test_dependency_check(self, mock_require: MagicMock) -> None:
        """Test that the dependency check is performed on init."""
        mock_agent = MagicMock()
        pydantic_ai(mock_agent)

        mock_require.assert_called_once_with(
            "pydantic-ai chat model requires pydantic-ai. `pip install pydantic-ai`"
        )

    def test_build_ui_messages_with_multiple_roles(self) -> None:
        """Test _build_ui_messages with system, user, and assistant roles."""
        from pydantic_ai.ui.vercel_ai.request_types import (
            TextUIPart,
            UIMessage,
        )

        model = self._make_model()

        messages = [
            ChatMessage(
                role="system",
                content="You are helpful",
                id="msg-1",
                parts=[TextPart(type="text", text="You are helpful")],
            ),
            ChatMessage(
                role="user",
                content="Hello",
                id="msg-2",
                parts=[TextPart(type="text", text="Hello")],
            ),
            ChatMessage(
                role="assistant",
                content="Hi there",
                id="msg-3",
                parts=[TextPart(type="text", text="Hi there")],
            ),
        ]

        ui_messages = model._build_ui_messages(messages)

        assert ui_messages == [
            UIMessage(
                id="msg-1",
                role="system",
                parts=[TextUIPart(type="text", text="You are helpful")],
                metadata=None,
            ),
            UIMessage(
                id="msg-2",
                role="user",
                parts=[TextUIPart(type="text", text="Hello")],
                metadata=None,
            ),
            UIMessage(
                id="msg-3",
                role="assistant",
                parts=[TextUIPart(type="text", text="Hi there")],
                metadata=None,
            ),
        ]

    def test_pydantic_ai_serialize_vercel_ai_chunk(self) -> None:
        """Test _serialize_vercel_ai_chunk with valid chunks."""
        from pydantic_ai.ui.vercel_ai.response_types import (
            TextDeltaChunk,
            ToolInputAvailableChunk,
        )

        model = self._make_model()

        # Test text-delta chunk
        text_chunk = TextDeltaChunk(id="1", type="text-delta", delta="Hello")
        result = model._serialize_vercel_ai_chunk(text_chunk)
        assert result == {"id": "1", "type": "text-delta", "delta": "Hello"}

        # Test tool-call chunk; keys come out camelCased (by alias)
        tool_chunk = ToolInputAvailableChunk(
            tool_name="search",
            tool_call_id="call-1",
            input={"query": "test"},
        )
        result = model._serialize_vercel_ai_chunk(tool_chunk)
        assert result == {
            "type": "tool-input-available",
            "toolCallId": "call-1",
            "toolName": "search",
            "input": {"query": "test"},
        }

    def test_pydantic_ai_serialize_vercel_ai_chunk_done_type(self) -> None:
        """Test that 'done' type chunks are skipped."""
        from pydantic_ai.ui.vercel_ai.response_types import DoneChunk

        model = self._make_model()

        # Test done chunk - should return None
        done_chunk = DoneChunk(type="done")
        result = model._serialize_vercel_ai_chunk(done_chunk)
        assert result is None

    def test_pydantic_ai_serialize_vercel_ai_chunk_error_handling(
        self,
    ) -> None:
        """Test error handling in _serialize_vercel_ai_chunk."""
        from typing import cast

        model = self._make_model()

        # Test chunk that raises error - should return None
        error_chunk = MockBaseChunkWithError()
        result = model._serialize_vercel_ai_chunk(cast(Any, error_chunk))
        assert result is None


class MockBaseChunkWithError:
    """Stand-in for a Vercel AI chunk whose serialization always fails."""

    def model_dump(
        self, mode: str, by_alias: bool, exclude_none: bool
    ) -> dict[str, Any]:
        """Accept ``model_dump``-style arguments, then raise ``ValueError``."""
        _ = (mode, by_alias, exclude_none)  # intentionally unused
        message = "Serialization error"
        raise ValueError(message)


@pytest.mark.skipif(
    DependencyManager.pydantic_ai.has(), reason="Pydantic AI is installed"
)
def test_pydantic_ai_require() -> None:
    """Building the model without pydantic-ai raises ModuleNotFoundError."""
    with pytest.raises(ModuleNotFoundError):
        pydantic_ai(MagicMock())
