import logging
import uuid
from asgiref.sync import sync_to_async
from functools import cached_property
from typing import Any, Dict, Type

from django.db import models
from langchain.schema.runnable import Runnable
from pydantic.v1 import BaseModel

from ix.chains.fixture_src.flow import ROOT_CLASS_PATH
from ix.ix_users.models import OwnedModel
from ix.pg_vector.tests.models import PGVectorMixin
from ix.pg_vector.utils import get_embedding
from ix.utils.pydantic import create_args_model_v1

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class NodeTypeQuery(PGVectorMixin, models.QuerySet):
    """QuerySet for NodeType: Django's default QuerySet with PGVectorMixin mixed in."""


class NodeTypeManager(models.Manager.from_queryset(NodeTypeQuery)):
    """Default NodeType manager, backed by the PGVector-enabled NodeTypeQuery."""

    def get_by_natural_key(self, class_path):
        """Fetch a NodeType by its natural key, ``class_path``."""
        lookup = {"class_path": class_path}
        return self.get(**lookup)

    def create_with_embedding(self, name, description, class_path):
        """
        Create a NodeType together with a vector ``embedding``.

        The embedding is produced by ``get_embedding`` from ``name``,
        ``description`` and ``class_path`` joined with single spaces.
        """
        embedding_source = " ".join(
            f"{part}" for part in (name, description, class_path)
        )
        field_values = dict(name=name, description=description, class_path=class_path)
        field_values["embedding"] = get_embedding(embedding_source)
        return self.create(**field_values)


class NodeType(OwnedModel):
    """
    Type definition referenced by ``ChainNode.node_type``.

    ``class_path`` is unique and serves as the natural key. ``connectors``,
    ``fields`` and ``field_groups`` store the component's structure as JSON.
    """

    # (value, label) pairs of known node types.
    # NOTE(review): not passed as ``choices`` to the ``type`` field below, so
    # values are not validated against this list there.
    TYPES = [
        ("agent", "agent"),
        ("chain", "chain"),
        ("chain_list", "chain_list"),
        ("document_loader", "document_loader"),
        ("embeddings", "embeddings"),
        ("index", "index"),
        ("llm", "llm"),
        ("memory", "memory"),
        ("memory_backend", "memory_backend"),
        ("prompt", "prompt"),
        ("retriever", "retriever"),
        ("tool", "tool"),
        ("toolkit", "toolkit"),
        ("text_splitter", "text_splitter"),
    ]

    # info
    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    name = models.CharField(max_length=255)
    description = models.TextField(null=True)
    # natural key (see natural_key / NodeTypeManager.get_by_natural_key)
    class_path = models.CharField(max_length=255, unique=True)
    type = models.CharField(max_length=255)

    # deprecated
    display_type = models.CharField(
        max_length=10,
        default="node",
        choices=(("node", "node"), ("list", "list"), ("map", "map")),
    )

    # structure
    # NOTE(review): connectors appear to be a list of dicts with at least
    # "key" and "type" entries (read by connectors_as_dict and by
    # ChainNodeManager.create_from_config) — confirm the full schema.
    connectors = models.JSONField(null=True)
    fields = models.JSONField(null=True)
    field_groups = models.JSONField(null=True)

    # variable to load context in when loading node. Generally only used by IX
    # internal components like ChainReference that need to load a chain.
    context = models.CharField(max_length=32, null=True)

    # child_field is the name of the field that contains child nodes
    # used for parsing config objects
    child_field = models.CharField(max_length=32, null=True)

    # JSONSchema for the config object
    config_schema = models.JSONField(default=dict)

    objects = NodeTypeManager()

    @cached_property
    def connectors_as_dict(self):
        """
        Return the connectors keyed by their ``key`` (empty dict when unset).

        Cached per instance: later changes to ``connectors`` on the same
        instance are not reflected.
        """
        return {c["key"]: c for c in self.connectors or []}

    def __str__(self):
        return f"{self.class_path}"

    def natural_key(self):
        """Return ``(class_path,)``; pairs with NodeTypeManager.get_by_natural_key."""
        return (self.class_path,)


def default_position():
    """
    Build the default graph position, ``{"x": 0, "y": 0}``.

    Used as a callable field default so each row gets its own dict instead of
    sharing one mutable object.
    """
    return dict(x=0, y=0)


class ChainNodeManager(models.Manager):
    """Manager for ChainNode that can build node graphs from config dicts."""

    def create_from_config(
        self, chain, config: Dict[str, Any], root=False, parent=None
    ) -> "ChainNode":
        """
        Create a node (and its nested nodes) from a config dict.

        The NodeType is looked up from ``config["class_path"]``. Its
        ``connectors`` decide which keys of ``config["config"]`` are nested
        property nodes, and its ``child_field`` names the property that holds
        child nodes. Property and child nodes are created recursively and
        connected with ChainEdge rows.

        Args:
            chain: chain that the created nodes and edges belong to.
            config: node definition. ``class_path`` is required. ``config`` and
                ``hidden`` are consumed here. Every other key is passed to
                ``create()`` as a model field. The caller's dict is not mutated.
            root: mark the created node as the chain root. For a hidden node
                the flag is passed on to its first child instead.
            parent: unused; kept for backwards compatibility.

        Returns:
            The created node. A hidden node is not saved. Its first child (the
            head of its child sequence) is returned instead, or ``None`` if it
            has no children.

        Raises:
            NodeType.DoesNotExist: if no NodeType matches ``class_path``.
            ValueError: if a hidden node defines properties.
        """
        # work on a copy since keys are popped off below
        config = config.copy()

        # get the node type
        class_path = config["class_path"]
        logger.debug(f"creating node from config class_path={class_path}")

        try:
            node_type = NodeType.objects.get(class_path=class_path)
        except NodeType.DoesNotExist:
            logger.error(f"NodeType with class_path={class_path} does not exist")
            raise

        # Pop nested property/child configs off before creating the node so they
        # are not stored in the node's own config. The nested config may be null.
        node_config = dict(config.pop("config", None) or {})
        property_configs = {}
        for connector in node_type.connectors or []:
            if connector["type"] != "target":
                continue
            key = connector["key"]
            if key in node_config:
                logger.debug(f"adding property key={key}")
                property_configs[key] = node_config.pop(key)

        child_configs = []
        if node_type.child_field is not None:
            children = property_configs.pop(node_type.child_field, None)
            if children is not None:
                # accept a single child config as well as a list of them
                child_configs = children if isinstance(children, list) else [children]

        is_hidden = config.pop("hidden", False)
        if is_hidden and property_configs:
            logger.error(
                f"class_path={class_path} has properties but is hidden, properties={property_configs}"
            )
            raise ValueError("hidden nodes cannot have properties")

        # create this node if visible; hidden nodes are never saved
        node = None
        if not is_hidden:
            node = self.create(
                chain=chain,
                node_type=node_type,
                root=root,
                position={"x": 0, "y": 0},
                config=node_config,
                **config,
            )
            self._create_property_nodes(chain, node, property_configs)

        # Handle children: Nodes with children may be hidden or visible
        # Hidden nodes are used with SequentialNodes to simplify the graph
        # UX. The children are visible and linked together. SequentialNodes
        # when visible display the children as a property node. This allows
        # both a simplified graph where nodes are linked together, and also
        # supports adding common properties to the SequentialNode when needed.
        first_child = None
        if node_type.child_field is not None:
            node_id = node.id if node is not None else None
            logger.debug(
                f"node_id={node_id} class_path={class_path} "
                f"loading children from field={node_type.child_field}"
            )
            first_child = self._create_child_nodes(
                chain,
                node,
                node_type.child_field,
                child_configs,
                # only a hidden node hands its root flag down to its first child
                root=root and is_hidden,
            )

        if node is None:
            # hidden node: return the head of its child sequence instead
            return first_child

        logger.debug(f"created node_id={node.id} class_path={node.class_path}")
        return node

    def _create_property_nodes(self, chain, node, property_configs):
        """Create nested property nodes and PROP edges from them to ``node``."""
        for key, property_config_group in property_configs.items():
            logger.debug(f"creating property node for key={key}")
            if not isinstance(property_config_group, list):
                property_config_group = [property_config_group]

            for property_config in property_config_group:
                nested_node = self.create_from_config(
                    chain=chain, config=property_config
                )
                if nested_node is None:
                    # a hidden property node without children has nothing to attach
                    continue
                ChainEdge.objects.create(
                    chain_id=node.chain_id,
                    source=nested_node,
                    source_key=nested_node.node_type.type,
                    target=node,
                    relation="PROP",
                    target_key=key,
                )

    def _create_child_nodes(self, chain, node, child_field, child_configs, root):
        """
        Create child nodes in order and connect them.

        Adjacent children are joined with LINK edges (``out`` -> ``in``). When
        ``node`` is visible (not None), the first child is also attached to it
        with a PROP edge on ``child_field``. ``root`` is applied to the first
        child only. Returns the first child created, or None.
        """
        node_id = node.id if node is not None else None
        first_child = None
        previous = None
        for i, child_config in enumerate(child_configs):
            logger.debug(
                f"node_id={node_id} creating child i={i} "
                f"child={child_config.get('class_path')}"
            )

            child = self.create_from_config(
                chain=chain, config=child_config, root=root and i == 0
            )
            if child is None:
                # a hidden child without children yields no node to connect
                continue

            # Link adjacent siblings
            if previous is not None:
                ChainEdge.objects.create(
                    chain=chain,
                    source=previous,
                    target=child,
                    source_key="out",
                    target_key="in",
                    relation="LINK",
                )

            if first_child is None:
                first_child = child
                # Add first node as property when visible
                if node is not None:
                    ChainEdge.objects.create(
                        chain=chain,
                        source=child,
                        target=node,
                        relation="PROP",
                        source_key=child.node_type.type,
                        target_key=child_field,
                    )

            previous = child
        return first_child


class ChainNode(models.Model):
    """
    A node in a chain's graph, typed by a NodeType.

    Nodes are connected by ChainEdge rows, reachable through the
    ``incoming_edges`` / ``outgoing_edges`` reverse relations.
    """

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    # set from the config dict's class_path when built by
    # ChainNodeManager.create_from_config
    class_path = models.CharField(max_length=255)
    node_type = models.ForeignKey(NodeType, on_delete=models.CASCADE, null=True)
    # node-specific settings; property/child configs are stripped out by
    # create_from_config and stored as separate nodes instead
    config = models.JSONField(null=True, default=dict)
    name = models.CharField(max_length=255, null=True)
    description = models.TextField(null=True)

    # node is root of graph
    root = models.BooleanField(default=False)

    # graph position
    position = models.JSONField(default=default_position)

    # parent chain
    chain = models.ForeignKey(
        "Chain",
        on_delete=models.CASCADE,
        related_name="nodes",
        null=True,
        blank=True,
    )

    # Annotations for reverse relations / attributes Django adds at runtime,
    # declared here for type checkers.
    incoming_edges: models.QuerySet["ChainEdge"]
    outgoing_edges: models.QuerySet["ChainEdge"]
    DoesNotExist: Type[models.ObjectDoesNotExist]

    objects = ChainNodeManager()

    def __str__(self):
        # short (8-char) id prefix plus class path, e.g. "1a2b3c4d (pkg.Cls)"
        return f"{str(self.id)[:8]} ({self.class_path})"


class ChainEdge(models.Model):
    """
    A directed edge from ``source`` to ``target`` ChainNode.

    As created by ChainNodeManager.create_from_config, "PROP" edges attach
    the source node to a property (``target_key``) of the target node, and
    "LINK" edges join sequential sibling nodes ("out" -> "in").
    """

    RELATION_CHOICES = (("PROP", "prop"), ("LINK", "link"))

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    source = models.ForeignKey(
        ChainNode, on_delete=models.CASCADE, related_name="outgoing_edges"
    )
    target = models.ForeignKey(
        ChainNode, on_delete=models.CASCADE, related_name="incoming_edges"
    )

    # connector keys on the source / target node
    source_key = models.CharField(max_length=255, null=True)
    target_key = models.CharField(max_length=255, null=True)

    chain = models.ForeignKey(
        "Chain", on_delete=models.CASCADE, related_name="edges", null=True
    )
    # NOTE(review): not read or written anywhere in this file — confirm usage.
    input_map = models.JSONField(null=True)
    relation = models.CharField(
        max_length=5, null=True, choices=RELATION_CHOICES, default="LINK"
    )

    # declared for type checkers; Django attaches it at runtime
    DoesNotExist: Type[models.ObjectDoesNotExist]


class Chain(OwnedModel):
    """
    A named chain that can be run by an Agent.

    The chain's graph is stored as ChainNode rows (reverse relation ``nodes``)
    connected by ChainEdge rows. The node flagged ``root=True`` is the chain's
    root node.
    """

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    name = models.CharField(max_length=128)
    description = models.TextField()
    created_at = models.DateTimeField(auto_now_add=True)

    # Indicate that this chain is an agent. This is used to record the config choice.
    # The endpoints are responsible for ensuring that the agent does or does not exist.
    is_agent = models.BooleanField(default=True)

    # reverse relation from ChainNode.chain, declared for type checkers
    nodes: models.QuerySet[ChainNode]

    @property
    def root(self) -> ChainNode:
        """
        Return the node flagged ``root=True``.

        Raises:
            ValueError: if the chain has no root node.
        """
        try:
            return self.nodes.get(root=True)
        except ChainNode.DoesNotExist as exc:
            raise ValueError(
                f"Chain chain_id={self.id} does not have a root node"
            ) from exc

    def __str__(self):
        return f"{self.name} ({self.id})"

    def load_chain(self, context: "IxContext") -> Runnable:  # noqa: F821
        """Initialize this chain as a Runnable via ``init_chain_flow``."""
        # local import, presumably to avoid a circular import with the loaders
        from ix.chains.loaders.core import init_chain_flow

        return init_chain_flow(self, context=context)

    async def aload_chain(self, context: "IxContext") -> Runnable:  # noqa: F821
        """Async variant of load_chain; wraps the sync loader with sync_to_async."""
        from ix.chains.loaders.core import init_chain_flow

        return await sync_to_async(init_chain_flow)(self, context=context)

    def clear_chain(self):
        """removes the chain nodes associated with this chain"""
        # clear old chain
        ChainNode.objects.filter(chain_id=self.id).delete()

    @cached_property
    def chat_root(self):
        """
        Return the root node whose class_path is ROOT_CLASS_PATH.

        Raises ChainNode.DoesNotExist when there is none; a failed lookup is not
        cached, so it is retried on the next access.
        """
        return self.nodes.get(root=True, class_path=ROOT_CLASS_PATH)

    @cached_property
    def types(self) -> Type[BaseModel]:
        """
        Build the pydantic (v1) model describing this chain's input.

        The returned ``ChainConfig`` model has an ``input`` field built from
        the chat root's ``outputs`` config and a ``config`` field built from
        its ``config`` config. Both generated models are also exposed as the
        class attributes ``INPUT`` and ``CONFIG``. Chains without a chat root
        fall back to an input of ``user_input`` and ``artifact_ids``.
        """
        try:
            root = self.chat_root
        except ChainNode.DoesNotExist:
            # fallback to old style roots:
            # TODO: remove this fallback after all chains have been migrated
            input_type = create_args_model_v1(
                ["user_input", "artifact_ids"], name="ChainInput"
            )
            config_type = create_args_model_v1([], name="ChainConfig")
        else:
            # ChainNode.config is nullable
            root_config = root.config or {}
            input_type = create_args_model_v1(
                root_config.get("outputs", []), name="ChainInput"
            )
            config_type = create_args_model_v1(
                root_config.get("config", []), name="ChainConfig"
            )

        class ChainConfig(BaseModel):
            input: input_type
            config: config_type = {}

            INPUT = input_type
            CONFIG = config_type

        return ChainConfig
