# Copyright © 2026 Pathway

import base64
import datetime
import json
import pathlib
import threading
import time
import uuid

import pytest

import pathway as pw
from pathway.internals.parse_graph import G
from pathway.tests.utils import (
    FileLinesNumberChecker,
    expect_csv_checker,
    wait_result_with_checker,
)

from .utils import (
    SCHEMA_REGISTRY_BASE_ROUTE,
    KafkaTestContext,
    create_schema_in_registry,
)


@pytest.mark.parametrize("with_metadata", [False, True])
@pytest.mark.flaky(reruns=3)
def test_kafka_raw(with_metadata, tmp_path, kafka_context):
    """Plaintext messages from the input topic reach the CSV output.

    Runs both with and without ``with_metadata``; only the ``data`` column of
    the output is compared.
    """
    kafka_context.fill(["foo", "bar"])
    csv_path = tmp_path / "output.csv"

    messages = pw.io.kafka.read(
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic=kafka_context.input_topic,
        format="plaintext",
        autocommit_duration_ms=100,
        with_metadata=with_metadata,
    )
    pw.io.csv.write(messages, csv_path)

    expected_csv = """
            data
            foo
            bar
            """
    checker = expect_csv_checker(
        expected_csv,
        csv_path,
        usecols=["data"],
        index_col=["data"],
    )
    wait_result_with_checker(checker, 10)


@pytest.mark.parametrize("with_metadata", [False, True])
@pytest.mark.parametrize("input_format", ["plaintext", "raw"])
@pytest.mark.flaky(reruns=3)
def test_kafka_key_parsing(input_format, with_metadata, tmp_path, kafka_context):
    """Message keys and values, including null ones, survive a static read.

    For the ``raw`` format the ``key``/``data`` fields of the JSON Lines output
    are base64 text and are decoded before comparison. For ``plaintext`` they
    are compared as-is.
    """
    context = [
        ("1", "one"),
        ("2", "two"),
        ("3", "three"),
        ("4", None),
        (None, "five"),
    ]
    kafka_context.fill(context)

    output_path = tmp_path / "output.jsonl"
    table = pw.io.kafka.read(
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic=kafka_context.input_topic,
        format=input_format,
        autocommit_duration_ms=100,
        with_metadata=with_metadata,
        mode="static",
    )
    pw.io.jsonlines.write(table, output_path)
    pw.run()

    def decode(field):
        # Only non-null raw fields are base64-encoded; everything else is kept.
        if input_format != "raw" or field is None:
            return field
        return base64.b64decode(field).decode("utf-8")

    with open(output_path, "r") as f:
        rows = [json.loads(line) for line in f]

    # None keys are sorted via their string form, like in `context` below.
    parsed_values = sorted(
        ((decode(row["key"]), decode(row["data"])) for row in rows),
        key=lambda pair: str(pair[0]),
    )
    context.sort(key=lambda pair: str(pair[0]))
    assert parsed_values == context


@pytest.mark.flaky(reruns=3)
def test_kafka_static_mode(tmp_path, kafka_context):
    """A ``mode="static"`` read lets ``pw.run()`` return with all messages written.

    The output file is inspected only after ``pw.run()`` has returned.
    """
    kafka_context.fill(["foo", "bar"])
    output_path = tmp_path / "output.jsonl"

    table = pw.io.kafka.read(
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic=kafka_context.input_topic,
        format="plaintext",
        autocommit_duration_ms=100,
        mode="static",
    )
    pw.io.jsonlines.write(table, output_path)
    pw.run()

    with open(output_path, "r") as f:
        result = {json.loads(row)["data"] for row in f}
    assert result == {"foo", "bar"}


@pytest.mark.flaky(reruns=3)
def test_kafka_message_metadata(tmp_path, kafka_context):
    """``with_metadata=True`` exposes topic, partition, offset and message headers.

    Header values in the output are base64 text. They are decoded and compared,
    order-insensitively, with the headers sent. Repeated keys, empty values and
    non-UTF-8 values are all included.
    """
    test_kafka_foo_message_headers = [
        ("X-Sender-ID", b"pathway-integration-test"),
        ("X-Trace-ID", b"a8acf0a5-009f-4035-9aca-834bc85929f9"),
        ("X-Trace-ID", b"7a21cee9-c081-4d64-add1-06e2e5e592d6"),
        ("X-Origin", b""),
        ("X-Signature", bytes([0, 255, 128, 10])),
    ]
    test_kafka_bar_message_headers = [
        ("X-Sender-ID", b"pathway-integration-test"),
        ("X-Trace-ID", b"ee6e3017-d77f-43d9-abf6-c33bd51e27ef"),
        ("X-Trace-ID", b"092565ae-aa1e-406c-a53f-d2c4d6f2397c"),
        ("X-Trace-ID", b"1d0ae9e7-8cac-40d8-9072-3d1a919a2fef"),
        ("X-Origin", b"Server"),
        ("X-Signature", bytes([0, 255, 128, 10, 17])),
    ]
    expected_headers_by_data = {
        "foo": test_kafka_foo_message_headers,
        "bar": test_kafka_bar_message_headers,
    }

    def check_headers(parsed: list[list[str]], original: list[tuple[str, bytes]]):
        # Compare as sorted sequences so that header order does not matter.
        decoded_headers = sorted(
            (key, base64.b64decode(value)) for key, value in parsed
        )
        assert decoded_headers == sorted(original)

    kafka_context.fill(["foo"], headers=test_kafka_foo_message_headers)
    kafka_context.fill(["bar"], headers=test_kafka_bar_message_headers)

    table = pw.io.kafka.read(
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic=kafka_context.input_topic,
        format="plaintext",
        autocommit_duration_ms=100,
        with_metadata=True,
    )
    output_path = tmp_path / "output.jsonl"

    pw.io.jsonlines.write(table, output_path)
    wait_result_with_checker(FileLinesNumberChecker(output_path, 2), 10)

    offsets = set()
    with open(output_path, "r") as f:
        for row in f:
            data = json.loads(row)
            metadata = data["_metadata"]
            assert metadata["topic"] == kafka_context.input_topic
            assert "partition" in metadata
            assert "offset" in metadata
            offsets.add(metadata["offset"])

            assert "headers" in metadata
            payload = data["data"]
            if payload not in expected_headers_by_data:
                raise ValueError(f"unknown message data: {data['data']}")
            check_headers(metadata["headers"], expected_headers_by_data[payload])

    # The two messages must have been read from distinct offsets.
    assert len(offsets) == 2


# The Python Kafka client does not allow a header with a null value, although
# the Kafka protocol itself does. Hence such a header is produced here by
# Pathway's own Kafka writer and then read back.
def test_null_header(tmp_path, kafka_context):
    """A ``None`` value in a header column becomes a header with no value.

    Stage 1 copies the input topic to the output topic, with ``hdr`` as a
    header. Stage 2 reads the output topic back with metadata. It checks that
    exactly one user header (``hdr``) is present per message, besides the
    ``pathway_*`` ones.
    """
    output_path = tmp_path / "output.jsonl"
    kafka_context.fill(
        [
            json.dumps({"k": 0, "hdr": "foo"}),
            # We output this as a header having no value
            json.dumps({"k": 1, "hdr": None}),
            json.dumps({"k": 2, "hdr": "bar"}),
        ]
    )

    class InputSchema(pw.Schema):
        k: int = pw.column_definition(primary_key=True)
        hdr: str | None

    def read_static(topic, **extra_kwargs):
        return pw.io.kafka.read(
            rdkafka_settings=kafka_context.default_rdkafka_settings(),
            topic=topic,
            format="json",
            mode="static",
            schema=InputSchema,
            **extra_kwargs,
        )

    # Stage 1: input topic -> output topic, carrying `hdr` as a header.
    copied = read_static(kafka_context.input_topic)
    pw.io.kafka.write(
        copied,
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic_name=kafka_context.output_topic,
        format="json",
        headers=[pw.this.hdr],
    )
    pw.run()
    G.clear()

    # Stage 2: output topic -> JSON Lines file, including message metadata.
    reread = read_static(kafka_context.output_topic, with_metadata=True)
    pw.io.jsonlines.write(reread, output_path)
    pw.run()

    expected_header_values = {0: b"foo", 1: None, 2: b"bar"}
    n_rows = 0
    with open(output_path, "r") as f:
        for row in f:
            data = json.loads(row)
            key = data["k"]
            user_headers = [
                header
                for header in data["_metadata"]["headers"]
                if not header[0].startswith("pathway_")
            ]
            assert len(user_headers) == 1
            header_key, header_value = user_headers[0]
            if header_value is not None:
                header_value = base64.b64decode(header_value)
            assert header_key == "hdr"
            if key not in expected_header_values:
                raise ValueError(f"unknown key: {key}")
            assert header_value == expected_header_values[key]
            n_rows += 1

    assert n_rows == 3


@pytest.mark.parametrize("with_metadata", [False, True])
@pytest.mark.flaky(reruns=3)
def test_kafka_json(tmp_path, kafka_context, with_metadata):
    """JSON payloads are parsed into the columns of the declared schema."""
    rows = [(0, "foo"), (1, "bar"), (2, "baz")]
    kafka_context.fill([json.dumps({"k": k, "v": v}) for k, v in rows])

    class InputSchema(pw.Schema):
        k: int = pw.column_definition(primary_key=True)
        v: str

    csv_path = tmp_path / "output.csv"
    table = pw.io.kafka.read(
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic=kafka_context.input_topic,
        format="json",
        schema=InputSchema,
        with_metadata=with_metadata,
        autocommit_duration_ms=100,
    )
    pw.io.csv.write(table, csv_path)

    expected_csv = """
            k    | v
            0    | foo
            1    | bar
            2    | baz
            """
    checker = expect_csv_checker(
        expected_csv,
        csv_path,
        usecols=["v"],
        index_col=["k"],
    )
    wait_result_with_checker(checker, 10)


@pytest.mark.parametrize("with_metadata", [False, True])
@pytest.mark.flaky(reruns=3)
def test_kafka_json_key_parsing(tmp_path, kafka_context, with_metadata):
    """Columns can be taken from the JSON key and from the JSON payload.

    ``k`` is declared with ``source_component="key"`` and ``v`` with
    ``source_component="payload"``.
    """
    context = [
        (json.dumps({"k": k}), json.dumps({"v": v}))
        for k, v in [(0, "foo"), (1, "bar"), (2, "baz")]
    ]
    kafka_context.fill(context)

    class InputSchema(pw.Schema):
        k: int = pw.column_definition(primary_key=True, source_component="key")
        v: str = pw.column_definition(primary_key=True, source_component="payload")

    csv_path = tmp_path / "output.csv"
    table = pw.io.kafka.read(
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic=kafka_context.input_topic,
        format="json",
        schema=InputSchema,
        with_metadata=with_metadata,
        autocommit_duration_ms=100,
    )
    pw.io.csv.write(table, csv_path)

    expected_csv = """
            k    | v
            0    | foo
            1    | bar
            2    | baz
            """
    checker = expect_csv_checker(
        expected_csv,
        csv_path,
        usecols=["v"],
        index_col=["k"],
    )
    wait_result_with_checker(checker, 10)


@pytest.mark.parametrize("with_metadata", [False, True])
@pytest.mark.flaky(reruns=3)
def test_kafka_json_key_jsonpaths(tmp_path, kafka_context, with_metadata):
    """``json_field_paths`` selects nested fields from the key and the payload.

    ``k`` is extracted via ``/k/l`` from the key; the sibling ``m`` is ignored.
    ``v`` is extracted via ``/v/vv`` from the payload.
    """
    context = [
        (
            json.dumps({"k": {"l": key_l, "m": key_m}}),
            json.dumps({"v": {"vv": nested_value}}),
        )
        for key_l, key_m, nested_value in [(0, 3, "foo"), (1, 4, "bar"), (2, 5, "baz")]
    ]
    kafka_context.fill(context)

    class InputSchema(pw.Schema):
        k: int = pw.column_definition(primary_key=True, source_component="key")
        v: str = pw.column_definition(primary_key=True, source_component="payload")

    csv_path = tmp_path / "output.csv"
    table = pw.io.kafka.read(
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic=kafka_context.input_topic,
        format="json",
        schema=InputSchema,
        with_metadata=with_metadata,
        autocommit_duration_ms=100,
        json_field_paths={"k": "/k/l", "v": "/v/vv"},
    )
    pw.io.csv.write(table, csv_path)

    expected_csv = """
            k    | v
            0    | foo
            1    | bar
            2    | baz
            """
    checker = expect_csv_checker(
        expected_csv,
        csv_path,
        usecols=["v"],
        index_col=["k"],
    )
    wait_result_with_checker(checker, 10)


@pytest.mark.parametrize("with_metadata", [False, True])
@pytest.mark.parametrize("unparsable_value", ["abracadabra", None])
@pytest.mark.flaky(reruns=3)
def test_kafka_json_data_only_in_key(
    tmp_path, unparsable_value, kafka_context, with_metadata
):
    """All schema columns come from the key; the value is non-JSON or null.

    The message value is never consulted, so it must not break parsing.
    """
    context = [
        (json.dumps({"k": k, "v": v}), unparsable_value)
        for k, v in [(0, "foo"), (1, "bar"), (2, "baz")]
    ]
    kafka_context.fill(context)

    class InputSchema(pw.Schema):
        k: int = pw.column_definition(primary_key=True, source_component="key")
        v: str = pw.column_definition(primary_key=True, source_component="key")

    csv_path = tmp_path / "output.csv"
    table = pw.io.kafka.read(
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic=kafka_context.input_topic,
        format="json",
        schema=InputSchema,
        with_metadata=with_metadata,
        autocommit_duration_ms=100,
    )
    pw.io.csv.write(table, csv_path)

    expected_csv = """
            k    | v
            0    | foo
            1    | bar
            2    | baz
            """
    checker = expect_csv_checker(
        expected_csv,
        csv_path,
        usecols=["v"],
        index_col=["k"],
    )
    wait_result_with_checker(checker, 10)


@pytest.mark.flaky(reruns=3)
def test_kafka_simple_wrapper_bytes_io(
    tmp_path: pathlib.Path, kafka_context: KafkaTestContext
):
    """``simple_read`` without a ``format`` yields the payloads as bytes.

    The topic is read twice, with the parse graph cleared in between, to check
    that a fresh reader sees all the messages again. The JSON Lines writer
    emits bytes as base64-encoded strings, which are decoded for the check.
    """
    kafka_context.fill(["foo", "bar"])
    output_path = tmp_path / "output.jsonl"

    def read_topic_into_output():
        table = pw.io.kafka.simple_read(
            kafka_context.default_rdkafka_settings()["bootstrap.servers"],
            kafka_context.input_topic,
        )
        pw.io.jsonlines.write(table, output_path)
        wait_result_with_checker(FileLinesNumberChecker(output_path, 2), 10)

    read_topic_into_output()

    # check that reread will have all these messages again
    G.clear()
    read_topic_into_output()

    # Check output type: bytes are serialized as base64-encoded JSON strings.
    decoded_payloads = set()
    with open(output_path, "r") as f:
        for row in f:
            row_parsed = json.loads(row)
            assert isinstance(row_parsed["data"], str)
            decoded = base64.b64decode(row_parsed["data"])
            assert decoded in (b"foo", b"bar")
            decoded_payloads.add(decoded)

    # Both messages must reappear after the reread, not one of them twice.
    assert decoded_payloads == {b"foo", b"bar"}


@pytest.mark.flaky(reruns=3)
def test_kafka_simple_wrapper_plaintext_io(
    tmp_path: pathlib.Path, kafka_context: KafkaTestContext
):
    """``simple_read`` with ``format="plaintext"`` yields the payloads as strings.

    The topic is read twice, with the parse graph cleared in between, to check
    that a fresh reader sees all the messages again.
    """
    kafka_context.fill(["foo", "bar"])
    output_path = tmp_path / "output.jsonl"

    def read_topic_into_output():
        table = pw.io.kafka.simple_read(
            kafka_context.default_rdkafka_settings()["bootstrap.servers"],
            kafka_context.input_topic,
            format="plaintext",
        )
        pw.io.jsonlines.write(table, output_path)
        wait_result_with_checker(FileLinesNumberChecker(output_path, 2), 10)

    read_topic_into_output()

    # check that reread will have all these messages again
    G.clear()
    read_topic_into_output()

    # Check output type, parsed plaintext should be a string
    payloads = set()
    with open(output_path, "r") as f:
        for row in f:
            row_parsed = json.loads(row)
            assert isinstance(row_parsed["data"], str)
            assert row_parsed["data"] in ("foo", "bar")
            payloads.add(row_parsed["data"])

    # Both messages must reappear after the reread, not one of them twice.
    assert payloads == {"foo", "bar"}


@pytest.mark.flaky(reruns=3)
def test_kafka_output(tmp_path: pathlib.Path, kafka_context: KafkaTestContext):
    """Each line of a plaintext file becomes one message in the output topic."""
    input_path = tmp_path / "input"
    input_path.write_text("foo\nbar\n")

    lines = pw.io.plaintext.read(str(input_path), mode="static")
    pw.io.kafka.write(
        lines,
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic_name=kafka_context.output_topic,
    )
    pw.run()

    assert len(kafka_context.read_output_topic()) == 2


@pytest.mark.flaky(reruns=3)
def test_kafka_raw_bytes_output(
    tmp_path: pathlib.Path, kafka_context: KafkaTestContext
):
    """Binary file contents are written to Kafka with ``format="raw"``.

    Two input files are expected to produce two output messages.
    """
    input_path = tmp_path / "input"
    input_path.mkdir()
    for file_name in ("foo", "bar"):
        (input_path / file_name).write_text(file_name)

    binary_files = pw.io.fs.read(input_path, mode="static", format="binary")
    pw.io.kafka.write(
        binary_files,
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic_name=kafka_context.output_topic,
        format="raw",
    )
    pw.run()

    assert len(kafka_context.read_output_topic()) == 2


def get_test_binary_table(tmp_path):
    """Create files ``foo`` and ``bar`` under ``tmp_path / "input"`` and read them.

    Returns a static, binary-format ``pw.io.fs`` table with metadata enabled.
    """
    input_path = tmp_path / "input"
    input_path.mkdir()
    for file_name in ("foo", "bar"):
        (input_path / file_name).write_text(file_name)

    table = pw.io.fs.read(
        input_path,
        mode="static",
        format="binary",
        with_metadata=True,
    )
    return table


@pytest.mark.flaky(reruns=3)
@pytest.mark.parametrize(
    "key",
    [None, "data"],
)
@pytest.mark.parametrize(
    "headers",
    [
        [],
        ["data"],
        ["data", "_metadata"],
    ],
)
def test_kafka_raw_bytes_output_select_index(
    key, headers, tmp_path: pathlib.Path, kafka_context: KafkaTestContext
):
    """Raw output accepts selected columns as the message key and headers.

    Besides the requested header columns, every message is expected to carry
    ``pathway_time`` and ``pathway_diff`` headers.
    """
    table = get_test_binary_table(tmp_path)

    def column_or_none(column_name):
        # A column name of None means "do not pass this argument".
        return None if column_name is None else table[column_name]

    pw.io.kafka.write(
        table,
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic_name=kafka_context.output_topic,
        format="raw",
        value=table.data,
        key=column_or_none(key),
        headers=[column_or_none(header) for header in headers],
    )
    pw.run()

    output_topic_contents = kafka_context.read_output_topic(
        expected_headers=["pathway_time", "pathway_diff", *headers]
    )
    assert len(output_topic_contents) == 2


@pytest.mark.flaky(reruns=3)
def test_kafka_output_rename_headers(
    tmp_path: pathlib.Path, kafka_context: KafkaTestContext
):
    """Header names follow the names of the selected columns.

    ``data`` and ``_metadata`` are exposed as headers ``foo`` and ``bar``.
    """
    table = get_test_binary_table(tmp_path)
    renamed = table.select(foo=pw.this.data, bar=pw.this._metadata)

    pw.io.kafka.write(
        table,
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic_name=kafka_context.output_topic,
        format="raw",
        key=pw.this.data,
        value=pw.this.data,
        headers=list(renamed),
    )
    pw.run()

    expected_headers = ["pathway_time", "pathway_diff", "foo", "bar"]
    output_topic_contents = kafka_context.read_output_topic(
        expected_headers=expected_headers
    )
    assert len(output_topic_contents) == 2


@pytest.mark.flaky(reruns=3)
def test_kafka_plaintext_output(
    tmp_path: pathlib.Path, kafka_context: KafkaTestContext
):
    """Plaintext file contents are written to Kafka with ``format="plaintext"``.

    Two input files are expected to produce two output messages.
    """
    input_path = tmp_path / "input"
    input_path.mkdir()
    for file_name in ("foo", "bar"):
        (input_path / file_name).write_text(file_name)

    text_files = pw.io.fs.read(input_path, mode="static", format="plaintext")
    pw.io.kafka.write(
        text_files,
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic_name=kafka_context.output_topic,
        format="plaintext",
    )
    pw.run()

    assert len(kafka_context.read_output_topic()) == 2


@pytest.mark.flaky(reruns=3)
def test_kafka_recovery(tmp_path: pathlib.Path, kafka_context: KafkaTestContext):
    """A persisted Kafka reader does not re-emit already processed messages.

    The first run processes three messages with filesystem persistence. The
    second run uses the same reader configuration (including ``name="1"``) and
    the same storage. Only the three messages appended afterwards are expected
    in its output.
    """
    persistent_storage_path = tmp_path / "PStorage"

    def fill_rows(rows):
        kafka_context.fill([json.dumps({"k": k, "v": v}) for k, v in rows])

    def read_input_topic():
        # Identical reader configuration for both runs.
        return pw.io.kafka.read(
            rdkafka_settings=kafka_context.default_rdkafka_settings(),
            topic=kafka_context.input_topic,
            format="json",
            schema=pw.schema_builder(
                columns={
                    "k": pw.column_definition(dtype=int, primary_key=True),
                    "v": pw.column_definition(dtype=str),
                }
            ),
            autocommit_duration_ms=100,
            name="1",
        )

    def persistence_kwargs():
        return {
            "persistence_config": pw.persistence.Config(
                pw.persistence.Backend.filesystem(persistent_storage_path),
            ),
        }

    # First run: process the initial three messages.
    fill_rows([(0, "foo"), (1, "bar"), (2, "baz")])
    pw.io.csv.write(read_input_topic(), tmp_path / "output.csv")
    wait_result_with_checker(
        expect_csv_checker(
            """
            k    | v
            0    | foo
            1    | bar
            2    | baz
            """,
            tmp_path / "output.csv",
            usecols=["v"],
            index_col=["k"],
        ),
        10,
        kwargs=persistence_kwargs(),
    )
    G.clear()

    # Second run: fill doesn't replace the messages, so we append 3 new ones
    fill_rows([(3, "foofoo"), (4, "barbar"), (5, "bazbaz")])
    pw.io.csv.write(read_input_topic(), tmp_path / "output_backfilled.csv")
    wait_result_with_checker(
        expect_csv_checker(
            """
            k    | v
            3    | foofoo
            4    | barbar
            5    | bazbaz
            """,
            tmp_path / "output_backfilled.csv",
            usecols=["v"],
            index_col=["k"],
        ),
        10,
        target=pw.run,
        kwargs=persistence_kwargs(),
    )


@pytest.mark.flaky(reruns=3)
def test_start_from_timestamp_ms_seek_to_middle(
    tmp_path: pathlib.Path, kafka_context: KafkaTestContext
):
    """Only messages produced after ``start_from_timestamp_ms`` are read.

    ``foo``/``bar`` are produced about 10 seconds before ``qqq``/``www``. The
    start timestamp is placed about 5 seconds before the second batch.
    """
    kafka_context.fill(["foo", "bar"])
    time.sleep(10)
    start_from_timestamp_ms = (int(time.time()) - 5) * 1000
    kafka_context.fill(["qqq", "www"])

    csv_path = tmp_path / "output.csv"
    table = pw.io.kafka.read(
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic=kafka_context.input_topic,
        format="plaintext",
        autocommit_duration_ms=100,
        start_from_timestamp_ms=start_from_timestamp_ms,
    )
    pw.io.csv.write(table, csv_path)

    expected_csv = """
            data
            qqq
            www
            """
    checker = expect_csv_checker(
        expected_csv,
        csv_path,
        usecols=["data"],
        index_col=["data"],
    )
    wait_result_with_checker(checker, 10)


@pytest.mark.flaky(reruns=3)
def test_start_from_timestamp_ms_seek_to_beginning(
    tmp_path: pathlib.Path, kafka_context: KafkaTestContext
):
    """A start timestamp older than every message reads the whole topic."""
    kafka_context.fill(["foo", "bar"])
    one_hour_ago_ms = (int(time.time()) - 3600) * 1000

    csv_path = tmp_path / "output.csv"
    table = pw.io.kafka.read(
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic=kafka_context.input_topic,
        format="plaintext",
        autocommit_duration_ms=100,
        start_from_timestamp_ms=one_hour_ago_ms,
    )
    pw.io.csv.write(table, csv_path)

    expected_csv = """
            data
            foo
            bar
            """
    checker = expect_csv_checker(
        expected_csv,
        csv_path,
        usecols=["data"],
        index_col=["data"],
    )
    wait_result_with_checker(checker, 10)


@pytest.mark.flaky(reruns=3)
def test_start_from_timestamp_ms_seek_to_end(
    tmp_path: pathlib.Path, kafka_context: KafkaTestContext
):
    """Messages produced after ``start_from_timestamp_ms`` are picked up live.

    ``foo``/``bar`` are produced about 5 seconds before the start timestamp and
    must not appear. Ten numbered messages are then produced, one per second,
    by a background thread while the pipeline is running.
    """
    kafka_context.fill(["foo", "bar"])
    time.sleep(10)
    start_from_timestamp_ms = int(time.time() - 5) * 1000

    table = pw.io.kafka.read(
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic=kafka_context.input_topic,
        format="plaintext",
        autocommit_duration_ms=100,
        start_from_timestamp_ms=start_from_timestamp_ms,
    )

    def stream_inputs():
        for i in range(10):
            kafka_context.fill([str(i)])
            time.sleep(1)

    # `Thread.start()` runs the producer concurrently with the pipeline below.
    # `Thread.run()` would run it synchronously in this thread, producing every
    # message before the pipeline even starts, so nothing would be streamed.
    producer = threading.Thread(target=stream_inputs, daemon=True)
    producer.start()

    pw.io.csv.write(table, tmp_path / "output.csv")

    try:
        wait_result_with_checker(
            expect_csv_checker(
                """
            data
            0
            1
            2
            3
            4
            5
            6
            7
            8
            9
            """,
                tmp_path / "output.csv",
                usecols=["data"],
                index_col=["data"],
            ),
            30,
        )
    finally:
        # Do not leave the producer writing into the topic after the test.
        producer.join()


@pytest.mark.flaky(reruns=3)
def test_kafka_json_key(tmp_path, kafka_context):
    """JSON output uses the chosen column as the key and columns as headers.

    Each message is expected to have the ``v`` value (UTF-8) as its key, a JSON
    value containing ``k`` and ``v``, and ``k``/``v`` headers holding their
    string forms as UTF-8 bytes.
    """
    input_rows = [
        {"k": 0, "v": "foo"},
        {"k": 1, "v": "bar"},
        {"k": 2, "v": "baz"},
    ]
    input_path = tmp_path / "input.jsonl"
    with open(input_path, "w") as f:
        for input_row in input_rows:
            f.write(json.dumps(input_row))
            f.write("\n")

    class InputSchema(pw.Schema):
        k: int = pw.column_definition(primary_key=True)
        v: str

    table = pw.io.jsonlines.read(input_path, schema=InputSchema, mode="static")
    pw.io.kafka.write(
        table,
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic_name=kafka_context.output_topic,
        format="json",
        key=table["v"],
        headers=[table["k"], table["v"]],
    )
    pw.run()

    output_topic_contents = kafka_context.read_output_topic()
    # Without this check, an empty output topic would let the loop below pass
    # without asserting anything.
    assert len(output_topic_contents) == len(input_rows)
    for message in output_topic_contents:
        key = message.key
        value = json.loads(message.value)
        assert value["v"].encode("utf-8") == key
        assert "k" in value
        headers = dict(message.headers)
        assert headers["k"] == str(value["k"]).encode("utf-8")
        assert headers["v"] == value["v"].encode("utf-8")


@pytest.mark.parametrize("output_format", ["json", "plaintext"])
@pytest.mark.flaky(reruns=3)
def test_kafka_dynamic_topics(tmp_path, kafka_context, output_format):
    input_path = tmp_path / "input.jsonl"
    dynamic_topic_1 = str(uuid.uuid4())
    dynamic_topic_2 = str(uuid.uuid4())
    kafka_context._create_topic(f"KafkaTopic.{dynamic_topic_1}")
    kafka_context._create_topic(f"KafkaTopic.{dynamic_topic_2}")
    with open(input_path, "w") as f:
        f.write(json.dumps({"k": "0", "v": "foo", "t": dynamic_topic_1}))
        f.write("\n")
        f.write(json.dumps({"k": "1", "v": "bar", "t": dynamic_topic_2}))
        f.write("\n")
        f.write(json.dumps({"k": "2", "v": "baz", "t": dynamic_topic_1}))
        f.write("\n")

    class InputSchema(pw.Schema):
        k: str = pw.column_definition(primary_key=True)
        v: str
        t: str

    table = pw.io.jsonlines.read(input_path, schema=InputSchema, mode="static")
    if output_format == "json":
        write_kwargs = {"format": "json"}
    elif output_format == "plaintext":
        write_kwargs = {
            "format": "plaintext",
            "key": table.k,
            "value": table.v,
        }
    else:
        raise RuntimeError(f"Unknown output format: {output_format}")
    pw.io.kafka.write(
        table,
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic_name=table.select(topic="KafkaTopic." + pw.this.t)["topic"],
        **write_kwargs,
    )
    pw.run()

    def check_keys_in_topic(topic_name, expected_keys):
        keys = set()
        output_topic_contents = kafka_context.read_topic(topic_name)
        for message in output_topic_contents:
            if output_format == "json":
                value = json.loads(message.value)
                keys.add(value["k"])
                assert value.keys() == {"k", "v", "t", "time", "diff", "topic"}
            else:
                keys.add(message.key.decode("utf-8"))
        assert keys == expected_keys

    check_keys_in_topic(f"KafkaTopic.{dynamic_topic_1}", {"0", "2"})
    check_keys_in_topic(f"KafkaTopic.{dynamic_topic_2}", {"1"})


@pytest.mark.flaky(reruns=3)
def test_kafka_registry(tmp_path, kafka_context):
    """Schema-registry-encoded JSON messages can be produced and read back.

    Step 1 writes two rows through the registry. The same topic is then read
    both through the registry and raw. Step 2 re-sends the raw encoded bytes as
    message keys with an empty value and reads them back with all columns taken
    from the key.
    """
    schema_subject = create_schema_in_registry(
        column_types={
            "key": "integer",
            "value": "string",
            "time": "integer",
            "diff": "integer",
        },
        required_columns=["key", "value", "time", "diff"],
    )

    input_path = tmp_path / "input.jsonl"
    output_path = tmp_path / "output.jsonl"
    raw_output_path = tmp_path / "output_raw.jsonl"
    input_entries = [
        {"key": 1, "value": "one"},
        {"key": 2, "value": "two"},
    ]

    with open(input_path, "w") as f:
        f.writelines(json.dumps(entry) + "\n" for entry in input_entries)

    def registry_settings():
        return pw.io.kafka.SchemaRegistrySettings(
            urls=[SCHEMA_REGISTRY_BASE_ROUTE],
            timeout=datetime.timedelta(seconds=5),
        )

    def load_sorted_entries(path):
        # Keep only the `key`/`value` fields of each written row, sorted by key.
        with open(path, "r") as f:
            entries = [
                {"key": data["key"], "value": data["value"]}
                for data in map(json.loads, f)
            ]
        return sorted(entries, key=lambda entry: entry["key"])

    class TableSchema(pw.Schema):
        key: int
        value: str

    table = pw.io.jsonlines.read(input_path, schema=TableSchema)

    pw.io.kafka.write(
        table,
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic_name=kafka_context.input_topic,
        format="json",
        schema_registry_settings=registry_settings(),
        subject=schema_subject,
    )

    table_reread = pw.io.kafka.read(
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic=kafka_context.input_topic,
        format="json",
        schema=TableSchema,
        autocommit_duration_ms=100,
        schema_registry_settings=registry_settings(),
    )
    table_raw = pw.io.kafka.read(
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic=kafka_context.input_topic,
        format="raw",
    )

    pw.io.jsonlines.write(table_reread, output_path)
    pw.io.jsonlines.write(table_raw, raw_output_path)
    wait_result_with_checker(
        FileLinesNumberChecker(output_path, 2).add_path(raw_output_path, 2), 30
    )
    assert load_sorted_entries(output_path) == input_entries

    # Send the data encoded by the registry as a key, while keeping the value as empty.
    # Check that value parsing works.
    additional_topic = kafka_context.create_additional_topic()
    with open(raw_output_path, "r") as f:
        for line in f:
            encoded_message = base64.b64decode(json.loads(line)["data"])
            kafka_context.send(message=(encoded_message, None), topic=additional_topic)

    class KeyTableSchema(pw.Schema):
        key: int = pw.column_definition(source_component="key")
        value: str = pw.column_definition(source_component="key")

    G.clear()
    table = pw.io.kafka.read(
        rdkafka_settings=kafka_context.default_rdkafka_settings(),
        topic=additional_topic,
        format="json",
        schema=KeyTableSchema,
        schema_registry_settings=registry_settings(),
        mode="static",
    )
    pw.io.jsonlines.write(table, output_path)
    pw.run(monitoring_level=pw.MonitoringLevel.NONE)

    assert load_sorted_entries(output_path) == input_entries
