import re
import sys
import threading

import pytest

import ray
from ray.exceptions import RayActorError, RayTaskError, UnserializableException

"""This module tests stacktrace of Ray.

There are 3 different types of stacktraces in Ray.

1. Not nested task (including actor creation) or actor task failure.
2. Chained task + actor task failure.
3. Dependency failure (upstreamed dependency raises an exception).

The important factors are:
- The root cause of the failure should be printed at the bottom.
- Ray-related code shouldn't be printed at all to the user-level stacktrace.
- It should be easy to follow stacktrace.

Each test verifies that there's no regression by comparing the scrubbed
stacktrace against the expected output. If we include unnecessary stacktrace
(e.g., logs from internal files), these tests will fail.
"""


def scrub_traceback(ex):
    """Normalize a stacktrace string so it can be compared across runs.

    The raw text is printed first to ease debugging. Leading/trailing
    newlines are stripped and then an ordered list of regex substitutions
    replaces run-specific details (pids, IPs, reprs, line numbers, quoted
    strings, color codes, module prefixes, object addresses, ...) with fixed
    placeholders. The order matters: later patterns rely on the placeholders
    produced by earlier ones (e.g. ``"FILE"`` and ``line ZZ``).

    Args:
        ex: The stacktrace text to scrub. Must be a ``str``.

    Returns:
        The scrubbed stacktrace text.
    """
    assert isinstance(ex, str)
    print(ex)

    # (pattern, replacement, flags) triples, applied strictly in this order.
    substitutions = (
        ("pid=[0-9]+,", "pid=XXX,", 0),
        (r"ip=[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+", "ip=YYY", 0),
        (r"repr=.*\)", "repr=ZZZ)", 0),
        ("line .*,", "line ZZ,", 0),
        ('".*"', '"FILE"', 0),
        # ANSI escape sequences used to color the string.
        (r"\x1b\[36m", "", 0),
        (r"\x1b\[39m", "", 0),
        # When running bazel test with pytest 6.x, the module name becomes
        # "python.ray.tests.test_traceback" instead of just "test_traceback".
        # Also remove the "io_ray" prefix, which may appear on Windows.
        (
            r"(io_ray.)?python\.ray\.tests\.test_traceback",
            "test_traceback",
            0,
        ),
        # Clean up object addresses.
        ("object at .*?>", "object at ADDRESS>", 0),
        # This output comes from ray.util.inspect_serializability().
        (
            r"=[\s\S]*Checking Serializability of[\s\S]*=",
            "INSPECT_SERIALIZABILITY",
            0,
        ),
        # Drop the caret/tilde marker lines in stack traces, which are new in
        # Python 3.12.
        ("^\\s+~*\\^+~*\n", "", re.MULTILINE),
        # Remove internal Cython frames from ray._raylet that can appear on
        # Windows.
        (
            r"^\s*File \"FILE\", line ZZ, in ray\._raylet\.[^\n]+\n",
            "",
            re.MULTILINE,
        ),
    )

    scrubbed = ex.strip("\n")
    for pattern, replacement, flags in substitutions:
        scrubbed = re.sub(pattern, replacement, scrubbed, flags=flags)
    return scrubbed


def clean_noqa(ex):
    """Return ``ex`` with every `` # noqa`` marker removed.

    The expected outputs in this file carry `` # noqa`` suffixes only to
    silence the linter on long lines; they are not part of the real
    stacktrace, so they are dropped before comparison.

    Args:
        ex: The expected-output text. Must be a ``str``.

    Returns:
        The text without any `` # noqa`` occurrences.
    """
    assert isinstance(ex, str)
    # The marker contains no regex metacharacters, so a plain substring
    # replacement is sufficient.
    return ex.replace(" # noqa", "")


@pytest.mark.skipif(
    sys.platform == "win32", reason="Clean stacktrace not supported on Windows"
)
def test_actor_creation_stacktrace(ray_start_regular):
    """Test the actor creation task stacktrace.

    The actor's ``__init__`` raises, so calling a method on it is expected to
    raise ``RayActorError``. The scrubbed error message must contain only the
    user-level frames, with the root cause at the bottom.
    """
    expected_output = """The actor died because of an error raised in its creation task, ray::A.__init__() (pid=XXX, ip=YYY, actor_id={actor_id}, repr=ZZZ) # noqa
  File "FILE", line ZZ, in __init__
    g(3)
  File "FILE", line ZZ, in g
    raise ValueError(a)
ValueError: 3"""

    def g(a):
        raise ValueError(a)

    @ray.remote
    class A:
        def __init__(self):
            g(3)

        def ping(self):
            pass

    a = A.remote()
    # Use pytest.raises rather than try/except so that the test fails, instead
    # of silently passing, when no RayActorError is raised at all.
    with pytest.raises(RayActorError) as excinfo:
        ray.get(a.ping.remote())

    print(excinfo.value)
    assert clean_noqa(
        expected_output.format(actor_id=a._actor_id.hex())
    ) == scrub_traceback(str(excinfo.value))


@pytest.mark.skipif(
    sys.platform == "win32", reason="Clean stacktrace not supported on Windows"
)
def test_task_stacktrace(ray_start_regular):
    """Test the normal task stacktrace.

    A remote function calls a local helper that raises ``ValueError``; the
    scrubbed message of the error raised by ``ray.get`` must show exactly the
    two user-level frames followed by the root cause.
    """
    expected_output = """ray::f() (pid=XXX, ip=YYY)
  File "FILE", line ZZ, in f
    return g(c)
  File "FILE", line ZZ, in g
    raise ValueError(a)
ValueError: 7"""

    def g(a):
        raise ValueError(a)

    @ray.remote
    def f():
        a = 3
        b = 4
        c = a + b
        return g(c)

    # Use pytest.raises rather than try/except so that the test fails, instead
    # of silently passing, when no ValueError is raised at all.
    with pytest.raises(ValueError) as excinfo:
        ray.get(f.remote())

    print(excinfo.value)
    assert clean_noqa(expected_output) == scrub_traceback(str(excinfo.value))


@pytest.mark.skipif(
    sys.platform == "win32", reason="Clean stacktrace not supported on Windows"
)
def test_actor_task_stacktrace(ray_start_regular):
    """Test the actor task stacktrace.

    An actor method calls a local helper that raises ``ValueError``; the
    scrubbed message of the error raised by ``ray.get`` must show the actor
    task header (including the actor id) and only the user-level frames.
    """
    expected_output = """ray::A.f() (pid=XXX, ip=YYY, actor_id={actor_id}, repr=ZZZ) # noqa
  File "FILE", line ZZ, in f
    return g(c)
  File "FILE", line ZZ, in g
    raise ValueError(a)
ValueError: 7"""

    def g(a):
        raise ValueError(a)

    @ray.remote
    class A:
        def f(self):
            a = 3
            b = 4
            c = a + b
            return g(c)

    a = A.remote()
    # Use pytest.raises rather than try/except so that the test fails, instead
    # of silently passing, when no ValueError is raised at all.
    with pytest.raises(ValueError) as excinfo:
        ray.get(a.f.remote())

    print(excinfo.value)
    assert clean_noqa(
        expected_output.format(actor_id=a._actor_id.hex())
    ) == scrub_traceback(str(excinfo.value))


@pytest.mark.skipif(
    sys.platform == "win32", reason="Clean stacktrace not supported on Windows"
)
def test_exception_chain(ray_start_regular):
    """Test the chained stacktrace.

    ``foo`` fetches the result of ``bar``, which raises ``ZeroDivisionError``.
    The error raised to the caller must be both a ``ZeroDivisionError`` and a
    ``RayTaskError``, and its scrubbed message must show ``foo``'s frame
    followed by ``bar``'s frame and the root cause.
    """
    expected_output = """ray::foo() (pid=XXX, ip=YYY) # noqa
  File "FILE", line ZZ, in foo
    return ray.get(bar.remote())
ray.exceptions.RayTaskError(ZeroDivisionError): ray::bar() (pid=XXX, ip=YYY)
  File "FILE", line ZZ, in bar
    return 1 / 0
ZeroDivisionError: division by zero"""

    @ray.remote
    def bar():
        return 1 / 0

    @ray.remote
    def foo():
        return ray.get(bar.remote())

    r = foo.remote()
    # Use pytest.raises rather than try/except so that the test fails, instead
    # of silently passing, when no ZeroDivisionError is raised at all.
    with pytest.raises(ZeroDivisionError) as excinfo:
        ray.get(r)

    assert isinstance(excinfo.value, RayTaskError)
    print(excinfo.value)
    assert clean_noqa(expected_output) == scrub_traceback(str(excinfo.value))


@pytest.mark.skipif(
    sys.platform == "win32", reason="Clean stacktrace not supported on Windows"
)
def test_dep_failure(ray_start_regular):
    """Test the stacktrace generated due to dependency failures.

    ``f`` depends on ``a``, which in turn depends on ``b``; ``b`` raises
    ``ValueError``. The scrubbed message must report each failed dependency in
    turn and end with ``b``'s frame and the root cause.
    """
    from pprint import pprint

    expected_output = """ray::f() (pid=XXX, ip=YYY) # noqa
  At least one of the input arguments for this task could not be computed:
ray.exceptions.RayTaskError: ray::a() (pid=XXX, ip=YYY)
  At least one of the input arguments for this task could not be computed:
ray.exceptions.RayTaskError: ray::b() (pid=XXX, ip=YYY)
  File "FILE", line ZZ, in b
    raise ValueError("FILE")
ValueError: b failed"""

    @ray.remote
    def f(a, b):
        pass

    @ray.remote
    def a(d):
        pass

    @ray.remote
    def b():
        raise ValueError("b failed")

    # Use pytest.raises rather than try/except so that the test fails, instead
    # of silently passing, when nothing is raised. The expected message is a
    # task error for ``f``, so require RayTaskError instead of any Exception.
    with pytest.raises(RayTaskError) as excinfo:
        ray.get(f.remote(a.remote(b.remote()), b.remote()))

    print(excinfo.value)
    expected = clean_noqa(expected_output)
    actual = scrub_traceback(str(excinfo.value))
    # Print both sides to make a mismatch easy to diagnose from the test log.
    pprint(expected)
    pprint(actual)
    assert expected == actual


@pytest.mark.skipif(
    sys.platform == "win32", reason="Clean stacktrace not supported on Windows"
)
def test_actor_repr_in_traceback(ray_start_regular):
    """Test that the ``repr`` label in an actor task error matches the actor.

    Two cases are checked: an actor class without a custom ``__repr__`` (the
    label must equal ``repr(self)`` as evaluated inside the actor) and an
    actor class that overrides ``__repr__``.
    """

    def parse_labels_from_traceback(ex):
        """Parse the ``key=value`` labels on the first line of ``str(ex)``.

        The labels are taken from the text after the second ``(`` on that
        line and are separated by ``", "``.

        Returns:
            A dict mapping each label name to its string value.
        """
        traceback_line = str(ex).split("\n")[0]
        unformatted_labels = traceback_line.split("(")[2].split(", ")
        label_dict = {}
        for label in unformatted_labels:
            # Remove parenthesis if included.
            if label.startswith("("):
                label = label[1:]
            elif label.endswith(")"):
                label = label[:-1]
            key, value = label.split("=", 1)
            label_dict[key] = value
        return label_dict

    # Test the default repr is Actor(repr=[class_name])
    def g(a):
        raise ValueError(a)

    @ray.remote
    class A:
        def f(self):
            a = 3
            b = 4
            c = a + b
            return g(c)

        def get_repr(self):
            return repr(self)

    a = A.remote()
    # Use pytest.raises rather than try/except so that the test fails, instead
    # of silently passing, when no ValueError is raised at all.
    with pytest.raises(ValueError) as excinfo:
        ray.get(a.f.remote())

    print(excinfo.value)
    label_dict = parse_labels_from_traceback(excinfo.value)
    assert label_dict["repr"] == ray.get(a.get_repr.remote())

    # Test if the repr is properly overwritten.
    actor_repr = "ABC"

    @ray.remote
    class A:
        def f(self):
            a = 3
            b = 4
            c = a + b
            return g(c)

        def __repr__(self):
            return actor_repr

    a = A.remote()
    with pytest.raises(ValueError) as excinfo:
        ray.get(a.f.remote())

    print(excinfo.value)
    label_dict = parse_labels_from_traceback(excinfo.value)
    assert label_dict["repr"] == actor_repr


def test_unpickleable_stacktrace(shutdown_only):
    """Test the error reported when a task's exception cannot be deserialized.

    ``ray.get`` is expected to raise ``UnserializableException`` whose
    scrubbed message contains a pointer to the serialization docs followed by
    the original task stacktrace.
    """
    expected_output = """Failed to deserialize exception. Refer to https://docs.ray.io/en/latest/ray-core/objects/serialization.html#custom-serializers-for-exceptions for more information.
Original exception:
ray.exceptions.RayTaskError: ray::f() (pid=XXX, ip=YYY)
  File "FILE", line ZZ, in f
    return g(c)
  File "FILE", line ZZ, in g
    raise NoPickleError("FILE")
test_traceback.NoPickleError"""

    # NOTE(review): __init__ requires an argument but never forwards it to
    # OSError; presumably this is what prevents the exception from being
    # re-created when it is unpickled -- confirm.
    class NoPickleError(OSError):
        def __init__(self, arg):
            pass

    def g(a):
        raise NoPickleError("asdf")

    # The source text of the frames below appears verbatim in expected_output,
    # so these lines must not be reworded.
    @ray.remote
    def f():
        a = 3
        b = 4
        c = a + b
        return g(c)

    with pytest.raises(UnserializableException) as excinfo:
        ray.get(f.remote())

    assert clean_noqa(expected_output) == scrub_traceback(str(excinfo.value))


def test_exception_with_registered_serializer(shutdown_only):
    """Test that an exception with a registered serializer reaches the caller.

    The remote task registers a custom serializer/deserializer pair for
    ``NoPickleError`` and then raises it. The error raised by ``ray.get`` must
    be both a ``RayTaskError`` and a ``NoPickleError``; the text appended by
    the deserializer must only be visible through the error's ``cause``.
    """

    class NoPickleError(OSError):
        def __init__(self, msg):
            self.msg = msg

        def __str__(self):
            return f"message: {self.msg}"

    def _to_state(err: NoPickleError):
        return {"msg": err.msg}

    def _from_state(state):
        return NoPickleError(state["msg"] + " deserialized")

    @ray.remote
    def raise_custom_exception():
        ray.util.register_serializer(
            NoPickleError, serializer=_to_state, deserializer=_from_state
        )
        raise NoPickleError("message")

    try:
        with pytest.raises(NoPickleError) as exc_info:
            ray.get(raise_custom_exception.remote())

        raised = exc_info.value
        # The raised error must be dual-typed. Without the custom serializer
        # it would be an UnserializableException instead of a NoPickleError.
        assert isinstance(raised, RayTaskError)
        assert isinstance(raised, NoPickleError)

        # The original message is propagated in the exception string ...
        assert "message" in str(raised)
        # ... while the modified message must only appear in the cause.
        assert "deserialized" not in str(raised)
        assert "message deserialized" in str(raised.cause)
    finally:
        ray.util.deregister_serializer(NoPickleError)


def test_task_error_with_read_only_args_property(ray_start_regular):
    """Test task errors whose exception class defines ``args`` as a property.

    ``ReadOnlyArgsError`` exposes ``args`` through a getter-only property. The
    error raised by ``ray.get`` must still be both a ``RayTaskError`` and a
    ``ReadOnlyArgsError``, and must carry the original message.
    """

    class ReadOnlyArgsError(Exception):
        def __init__(self, msg):
            self._msg = msg

        @property
        def args(self):
            return (self._msg,)

        def __str__(self):
            return self._msg

    @ray.remote
    def raise_read_only_args():
        raise ReadOnlyArgsError("boom")

    with pytest.raises(ReadOnlyArgsError) as exc_info:
        ray.get(raise_read_only_args.remote())

    raised = exc_info.value
    # Verify dual inheritance.
    for expected_type in (RayTaskError, ReadOnlyArgsError):
        assert isinstance(raised, expected_type)
    # Verify the args property works.
    assert raised.args == (raised.cause,)
    assert "boom" in str(raised)


def test_serialization_error_message(shutdown_only):
    """Test the ``TypeError`` messages raised for unserializable objects.

    A ``threading.Lock`` is used as the unserializable object. Five cases are
    covered: ``ray.put``, a task argument, an actor constructor argument, and
    the lock being captured by a remote function or by an actor class.
    """
    expected_output_ray_put = """Could not serialize the put value <unlocked _thread.lock object at ADDRESS>:\nINSPECT_SERIALIZABILITY"""  # noqa
    expected_output_task = """Could not serialize the argument <unlocked _thread.lock object at ADDRESS> for a task or actor test_traceback.test_serialization_error_message.<locals>.task_with_unserializable_arg:\nINSPECT_SERIALIZABILITY"""  # noqa
    expected_output_actor = """Could not serialize the argument <unlocked _thread.lock object at ADDRESS> for a task or actor test_traceback.test_serialization_error_message.<locals>.A.__init__:\nINSPECT_SERIALIZABILITY"""  # noqa
    expected_capture_output_task = """Could not serialize the function test_traceback.test_serialization_error_message.<locals>.capture_lock:\nINSPECT_SERIALIZABILITY"""  # noqa
    expected_capture_output_actor = """Could not serialize the actor class test_traceback.test_serialization_error_message.<locals>.B.__init__:\nINSPECT_SERIALIZABILITY"""  # noqa
    ray.init(num_cpus=1)
    lock = threading.Lock()

    @ray.remote
    def task_with_unserializable_arg(lock):
        print(lock)

    @ray.remote
    class A:
        def __init__(self, lock):
            print(lock)

    @ray.remote
    def capture_lock():
        print(lock)

    @ray.remote
    class B:
        def __init__(self):
            print(lock)

    # Test ray.put() an unserializable object.
    with pytest.raises(TypeError) as excinfo:
        ray.put(lock)
    expected = clean_noqa(expected_output_ray_put)
    assert expected == scrub_traceback(str(excinfo.value))

    # Test a task with an unserializable object.
    with pytest.raises(TypeError) as excinfo:
        task_with_unserializable_arg.remote(lock)
    expected = clean_noqa(expected_output_task)
    assert expected == scrub_traceback(str(excinfo.value))

    # Test an actor with an unserializable object.
    with pytest.raises(TypeError) as excinfo:
        actor_handle = A.remote(lock)
        print(actor_handle)
    expected = clean_noqa(expected_output_actor)
    assert expected == scrub_traceback(str(excinfo.value))

    # Test the case where an unserializable object is captured by tasks.
    with pytest.raises(TypeError) as excinfo:
        capture_lock.remote()
    expected = clean_noqa(expected_capture_output_task)
    assert expected == scrub_traceback(str(excinfo.value))

    # Test the case where an unserializable object is captured by actors.
    with pytest.raises(TypeError) as excinfo:
        actor_handle = B.remote()
        print(actor_handle)
    expected = clean_noqa(expected_capture_output_actor)
    assert expected == scrub_traceback(str(excinfo.value))


if __name__ == "__main__":
    # Run this file's tests with pytest when it is executed as a script, and
    # use pytest's return value as the process exit status.
    exit_code = pytest.main(["-sv", __file__])
    sys.exit(exit_code)
