#!/usr/bin/env python3

import marimo

__generated_with = "0.18.4"
app = marimo.App(width="medium")

# Setup block: the names bound here (mo, s, n, o, t and the three
# *_MESSAGES lists) are referenced by the cells and functions further down.
with app.setup:
    import marimo as mo
    import marimo._schemas.session as s
    import marimo._schemas.notebook as n
    import marimo._messaging.notification as o
    import typing as t

    # Classes from marimo._schemas.session; generate_schema uses this list
    # for the "session" schema.
    SESSION_MESSAGES = [
        # Session
        s.TimeMetadata,
        s.StreamOutput,
        s.StreamMediaOutput,
        s.ErrorOutput,
        s.DataOutput,
        s.Cell,
        s.NotebookSessionMetadata,
        s.NotebookSessionV1,
    ]
    # Classes from marimo._schemas.notebook; generate_schema uses this list
    # for the "notebook" schema.
    NOTEBOOK_MESSAGES = [
        # Notebook
        n.NotebookCell,
        n.NotebookCellConfig,
        n.NotebookMetadata,
        n.NotebookV1,
    ]
    # Classes from marimo._messaging.notification; generate_schema uses this
    # list for the "notifications" schema.
    OPS_MESSAGES = [
        # Message operations
        o.CellNotification,
        o.FunctionCallResultNotification,
        o.RemoveUIElementsNotification,
        o.UIElementMessageNotification,
        o.InterruptedNotification,
        o.CompletedRunNotification,
        o.KernelReadyNotification,
        o.CompletionResultNotification,
        o.AlertNotification,
        o.MissingPackageAlertNotification,
        o.InstallingPackageAlertNotification,
        o.ReconnectedNotification,
        o.StartupLogsNotification,
        o.BannerNotification,
        o.ReloadNotification,
        o.VariablesNotification,
        o.VariableValuesNotification,
        o.DatasetsNotification,
        o.SQLTablePreviewNotification,
        o.SQLTableListPreviewNotification,
        o.DataColumnPreviewNotification,
        o.DataSourceConnectionsNotification,
        o.ValidateSQLResultNotification,
        o.QueryParamsSetNotification,
        o.QueryParamsAppendNotification,
        o.QueryParamsDeleteNotification,
        o.QueryParamsClearNotification,
        o.FocusCellNotification,
        o.UpdateCellCodesNotification,
        o.SecretKeysResultNotification,
        o.CacheClearedNotification,
        o.CacheInfoNotification,
        o.UpdateCellIdsNotification,
    ]


@app.cell(hide_code=True)
def _():
    # Introductory markdown describing what this notebook does.
    mo.md(r"""
    This marimo notebook generates the OpenAPI schema for the `TypedDict`s defined in `marimo._schemas`
    """)
    return


@app.cell(hide_code=True)
def _():
    # Run button labelled "Write schema"; it is returned from this cell so
    # that another cell can read its value.
    generate_schema_button = mo.ui.run_button(label="Write schema")
    # Bare expression: presumably what marimo renders as this cell's
    # output -- confirm against the marimo file-format docs.
    generate_schema_button
    return (generate_schema_button,)


@app.function
def generate_schema(name: str):
    """Render the OpenAPI schema for one group of classes as YAML text.

    Args:
        name: Which group to render: ``"session"``, ``"notebook"`` or
            ``"notifications"``.

    Returns:
        A "generated file" header comment line followed by the YAML dump of
        ``build_openapi_schema`` for the selected classes.

    Raises:
        ValueError: If ``name`` is not one of the three known groups.
    """
    import yaml

    messages_by_name = {
        "session": SESSION_MESSAGES,
        "notebook": NOTEBOOK_MESSAGES,
        "notifications": OPS_MESSAGES,
    }
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would let an unknown name silently be rendered
    # as the notifications schema.
    if name not in messages_by_name:
        raise ValueError(
            f"Invalid schema name {name!r}: must be 'session', 'notebook', "
            "or 'notifications'"
        )

    header = "# This file is generated by scripts/generate_schemas.py\n"

    print(f"Writing schema to marimo/_schemas/generated/{name}.yaml")
    return header + yaml.dump(build_openapi_schema(messages_by_name[name]))


@app.function
def write_schema(name: str):
    """Generate the schema called ``name`` and save it as ``<name>.yaml``.

    The file is placed in ``marimo/_schemas/generated`` under the parent of
    the notebook's directory and is written as UTF-8 text.
    """
    generated_dir = mo.notebook_dir().parent.joinpath(
        "marimo", "_schemas", "generated"
    )
    target = generated_dir / f"{name}.yaml"
    target.write_text(generate_schema(name), encoding="utf-8")


@app.function
def build_openapi_schema(messages):
    """Build an OpenAPI 3.0.0 document with component schemas for classes.

    msgspec's JSON-schema generation is tried first. If it raises, the
    traceback is printed and every class is instead converted one by one
    with ``PythonTypeToOpenAPI``.

    Args:
        messages: The classes to describe.

    Returns:
        A dict with ``openapi``, ``info`` and ``components`` keys, where
        ``components["schemas"]`` maps schema names to their definitions.
    """
    import traceback

    import msgspec.json

    component_schemas: t.Optional[dict[str, t.Any]] = None
    try:
        # schema_components returns a (schemas, components) pair; only the
        # components mapping is published in the document.
        _, components = msgspec.json.schema_components(
            messages, ref_template="#/components/schemas/{name}"
        )
        component_schemas = dict(components)
    except Exception:
        # Deliberate best-effort: report the failure, then use the
        # fallback converter below.
        traceback.print_exc()

    if component_schemas is None:
        # Imported lazily so that a problem with the fallback converter
        # cannot break the msgspec path above.
        from marimo._utils.dataclass_to_openapi import (
            PythonTypeToOpenAPI,
        )

        processed_classes: dict[t.Any, str] = {}
        name_overrides: dict[t.Any, str] = {}
        converter = PythonTypeToOpenAPI(
            camel_case=False, name_overrides=name_overrides
        )

        component_schemas = {}
        for cls in messages:
            # Forget any earlier record of this class before converting it
            # as a top-level entry (presumably so it is converted in full
            # rather than treated as already processed -- confirm against
            # PythonTypeToOpenAPI.convert).
            processed_classes.pop(cls, None)
            name = name_overrides.get(cls, cls.__name__)  # type: ignore[attr-defined]
            component_schemas[name] = converter.convert(cls, processed_classes)
            processed_classes[cls] = name

    return {
        "openapi": "3.0.0",
        "info": {"title": "marimo"},
        "components": {"schemas": component_schemas},
    }


@app.cell
def _(generate_schema_button):
    # Write all three schema files when the app reports "script" mode, or
    # when the button's value is truthy.
    if mo.app_meta().mode == "script" or generate_schema_button.value:
        write_schema("session")
        write_schema("notebook")
        write_schema("notifications")
    return


# Run the app when this file is executed directly.
if __name__ == "__main__":
    app.run()
