# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any

import pytest

from haystack.core.component import component
from haystack.core.component.component import ComponentError


@component
class ValidComponent:
    """Component whose ``run`` and ``run_async`` have identical signatures."""

    def run(self, text: str) -> dict[str, Any]:
        """Echo ``text`` back under the ``result`` key."""
        return dict(result=text)

    async def run_async(self, text: str) -> dict[str, Any]:
        """Async twin of ``run``: echo ``text`` back under the ``result`` key."""
        return dict(result=text)


@component
class DifferentParamNameComponent:
    """Component whose ``run_async`` names its parameter ``input_text`` instead of ``text``."""

    def run(self, text: str) -> dict[str, Any]:
        """Echo ``text`` under the ``result`` key."""
        output = {"result": text}
        return output

    async def run_async(self, input_text: str) -> dict[str, Any]:
        """Echo ``input_text``; the parameter name deliberately differs from ``run``."""
        return dict(result=input_text)


@component
class DifferentParamTypeComponent:
    """Component whose ``run_async`` annotates ``text`` as ``list[str]`` while ``run`` uses ``str``."""

    def run(self, text: str) -> dict[str, Any]:
        """Echo ``text`` under the ``result`` key."""
        return dict(result=text)

    async def run_async(self, text: list[str]) -> dict[str, Any]:
        """Return the first element of ``text`` under the ``result`` key."""
        first = text[0]
        return {"result": first}


@component
class DifferentDefaultValueComponent:
    """Component whose ``run`` and ``run_async`` give ``text`` different default values."""

    def run(self, text: str = "default") -> dict[str, Any]:
        """Echo ``text`` (defaulting to ``"default"``) under ``result``."""
        output = {"result": text}
        return output

    async def run_async(self, text: str = "different") -> dict[str, Any]:
        """Echo ``text`` (defaulting to ``"different"``) under ``result``."""
        return dict(result=text)


@component
class DifferentParamKindComponent:
    """Component whose ``run_async`` makes ``text`` keyword-only while ``run`` does not."""

    def run(self, text: str) -> dict[str, Any]:
        """Echo ``text`` (positional-or-keyword) under ``result``."""
        return dict(result=text)

    async def run_async(self, *, text: str) -> dict[str, Any]:
        """Echo ``text`` (keyword-only, because of the bare ``*``) under ``result``."""
        output = {"result": text}
        return output


@component
class DifferentParamCountComponent:
    """Component whose ``run_async`` takes one more parameter (``extra``) than ``run``."""

    def run(self, text: str) -> dict[str, Any]:
        """Echo ``text`` under the ``result`` key."""
        return dict(result=text)

    async def run_async(self, text: str, extra: str) -> dict[str, Any]:
        """Return ``text + extra`` under the ``result`` key."""
        combined = text + extra
        return {"result": combined}


def test_valid_signatures():
    """A component with matching ``run``/``run_async`` signatures instantiates and both methods work."""
    import asyncio

    # Named ``valid`` rather than ``component`` so it does not shadow the imported decorator.
    valid = ValidComponent()
    assert valid.run("test") == {"result": "test"}

    # Also exercise the async variant; use a private loop so the global event loop is untouched.
    loop = asyncio.new_event_loop()
    try:
        assert loop.run_until_complete(valid.run_async("test")) == {"result": "test"}
    finally:
        loop.close()


def test_different_param_names():
    """Instantiating a component whose run/run_async parameter names differ raises ComponentError."""
    expected = "name mismatch"
    with pytest.raises(ComponentError, match=expected):
        DifferentParamNameComponent()


def test_different_param_types():
    """Instantiating a component whose run/run_async parameter annotations differ raises ComponentError."""
    expected = "type mismatch"
    with pytest.raises(ComponentError, match=expected):
        DifferentParamTypeComponent()


def test_different_default_values():
    """Instantiating a component whose run/run_async defaults differ raises ComponentError."""
    expected = "default value mismatch"
    with pytest.raises(ComponentError, match=expected):
        DifferentDefaultValueComponent()


def test_different_param_kinds():
    """Instantiating a component whose run/run_async parameter kinds differ raises ComponentError."""
    # Parentheses and the dot are regex metacharacters, hence the escapes in this raw pattern.
    pattern = r"kind \(POSITIONAL, KEYWORD, etc\.\) mismatch: "
    with pytest.raises(ComponentError, match=pattern):
        DifferentParamKindComponent()


def test_different_param_count():
    """Instantiating a component whose run/run_async take different numbers of parameters raises ComponentError."""
    expected = "Different number of parameters"
    with pytest.raises(ComponentError, match=expected):
        DifferentParamCountComponent()
