from ray.includes.common cimport (
    CGcsClientOptions,
    CGcsNodeState,
    PythonGetResourcesTotal,
    PythonGetNodeLabels
)

from ray.includes.unique_ids cimport (
    CActorID,
    CNodeID,
    CObjectID,
    CWorkerID,
    CPlacementGroupID
)

from ray.includes.global_state_accessor cimport (
    CGlobalStateAccessor,
    RedisDelKeyPrefixSync,
)

from ray.includes.optional cimport (
    optional,
    nullopt,
    make_optional
)

from libc.stdint cimport uint32_t as c_uint32_t, int32_t as c_int32_t
from libcpp.string cimport string as c_string
from libcpp.memory cimport make_unique as c_make_unique

cdef class GlobalStateAccessor:
    """Cython wrapper class of C++ `ray::gcs::GlobalStateAccessor`.

    Every method forwards to the wrapped C++ accessor. The C++ call itself is
    made with the GIL released (``with nogil``) so other Python threads keep
    running while the accessor queries the GCS. The C++ result is converted to
    Python objects only after the GIL has been re-acquired.

    Unless noted otherwise, ``c_string`` results reach Python as ``bytes`` and
    ``c_vector[c_string]`` results as a ``list`` of ``bytes``.
    """
    cdef:
        # The wrapped C++ accessor, created in __cinit__ and owned by this
        # object through the unique_ptr.
        unique_ptr[CGlobalStateAccessor] inner

    def __cinit__(self, GcsClientOptions gcs_options):
        """Create the wrapped C++ accessor.

        Construction does not call ``Connect()``; use :meth:`connect` for that.

        Args:
            gcs_options: Options whose native ``CGcsClientOptions`` are passed
                to the C++ constructor.
        """
        cdef CGcsClientOptions *opts
        opts = gcs_options.native()
        self.inner = c_make_unique[CGlobalStateAccessor](opts[0])

    def connect(self):
        """Connect the wrapped accessor.

        Returns:
            bool: The result of the C++ ``Connect()`` call.
        """
        cdef c_bool result
        with nogil:
            result = self.inner.get().Connect()
        return result

    def get_job_table(
        self, *, skip_submission_job_info_field=False, skip_is_running_tasks_field=False
    ):
        """Return every job info entry reported by ``GetAllJobInfo``.

        Args:
            skip_submission_job_info_field: Forwarded to ``GetAllJobInfo``;
                presumably skips filling in submission job info (TODO confirm).
            skip_is_running_tasks_field: Forwarded to ``GetAllJobInfo``;
                presumably skips filling in ``is_running_tasks`` (TODO confirm).

        Returns:
            list[bytes]: The entries returned by ``GetAllJobInfo``.
        """
        cdef c_vector[c_string] result
        # Convert the Python flags to C bools while the GIL is still held.
        cdef c_bool c_skip_submission_job_info_field = skip_submission_job_info_field
        cdef c_bool c_skip_is_running_tasks_field = skip_is_running_tasks_field

        with nogil:
            result = self.inner.get().GetAllJobInfo(
                c_skip_submission_job_info_field, c_skip_is_running_tasks_field)
        return result

    def get_next_job_id(self):
        """Return the next job ID as the integer produced by ``JobID::ToInt()``."""
        cdef CJobID cjob_id
        with nogil:
            cjob_id = self.inner.get().GetNextJobID()
        return cjob_id.ToInt()

    def get_node_table(self):
        """Return one dict per node reported by ``GetAllNodeInfo``.

        Each serialized entry is parsed into a ``CGcsNodeInfo`` and flattened
        into a dict. ``"NodeID"`` is the hex form of the node ID, and
        ``"Alive"`` is True when the node state is ``ALIVE``. ``"alive"`` is a
        lowercase alias of ``"Alive"``. ``"Resources"`` holds the node's total
        resources for alive nodes and is ``{}`` for dead ones. ``"Labels"``
        holds the node labels. String fields are decoded to ``str``.

        Returns:
            list[dict]: One dict per node.
        """
        cdef:
            c_vector[c_string] items
            c_string item
            CGcsNodeInfo c_node_info
            unordered_map[c_string, double] c_resources
        with nogil:
            items = self.inner.get().GetAllNodeInfo()
        results = []
        for item in items:
            # The same message object is reused; each item is parsed into it.
            c_node_info.ParseFromString(item)
            node_info = {
                "NodeID": ray._common.utils.binary_to_hex(c_node_info.node_id()),
                "Alive": c_node_info.state() == CGcsNodeState.ALIVE,
                "NodeManagerAddress": c_node_info.node_manager_address().decode(),
                "NodeManagerHostname": c_node_info.node_manager_hostname().decode(),
                "NodeManagerPort": c_node_info.node_manager_port(),
                "ObjectManagerPort": c_node_info.object_manager_port(),
                "ObjectStoreSocketName":
                    c_node_info.object_store_socket_name().decode(),
                "RayletSocketName": c_node_info.raylet_socket_name().decode(),
                "MetricsExportPort": c_node_info.metrics_export_port(),
                "MetricsAgentPort": c_node_info.metrics_agent_port(),
                "DashboardAgentListenPort": c_node_info.dashboard_agent_listen_port(),
                "NodeName": c_node_info.node_name().decode(),
                "RuntimeEnvAgentPort": c_node_info.runtime_env_agent_port(),
                "DeathReason": c_node_info.death_info().reason(),
                "DeathReasonMessage":
                    c_node_info.death_info().reason_message().decode(),
            }
            # Lowercase alias kept alongside "Alive".
            node_info["alive"] = node_info["Alive"]
            if node_info["Alive"]:
                c_resources = PythonGetResourcesTotal(c_node_info)
                node_info["Resources"] = {
                    key.decode(): value for key, value in c_resources
                }
            else:
                # Dead nodes report no resources, so skip the lookup entirely.
                node_info["Resources"] = {}
            c_labels = PythonGetNodeLabels(c_node_info)
            node_info["Labels"] = \
                {key.decode(): value.decode() for key, value in c_labels}
            results.append(node_info)
        return results

    def get_draining_nodes(self):
        """Return the nodes reported by ``GetDrainingNodes``.

        Returns:
            dict[str, int]: Maps each node's hex ID to the ``int64`` value the
            C++ accessor pairs with it (presumably a drain deadline; TODO
            confirm the unit).
        """
        cdef:
            unordered_map[CNodeID, int64_t] draining_nodes
            unordered_map[CNodeID, int64_t].iterator draining_nodes_it

        with nogil:
            draining_nodes = self.inner.get().GetDrainingNodes()
        draining_nodes_it = draining_nodes.begin()
        results = {}
        while draining_nodes_it != draining_nodes.end():
            draining_node_id = dereference(draining_nodes_it).first
            results[ray._common.utils.binary_to_hex(
                draining_node_id.Binary())] = dereference(draining_nodes_it).second
            postincrement(draining_nodes_it)

        return results

    def get_internal_kv(self, namespace, key):
        """Return the value ``GetInternalKV`` yields for ``(namespace, key)``.

        Args:
            namespace: Namespace, converted to a C++ string.
            key: Key, converted to a C++ string.

        Returns:
            The value as ``bytes``, or None if the C++ call returned no value.
        """
        cdef:
            c_string c_namespace = namespace
            c_string c_key = key
            unique_ptr[c_string] result
        with nogil:
            result = self.inner.get().GetInternalKV(c_namespace, c_key)
        if result:
            return c_string(result.get().data(), result.get().size())
        return None

    def get_all_available_resources(self):
        """Return the entries produced by ``GetAllAvailableResources``."""
        cdef c_vector[c_string] result
        with nogil:
            result = self.inner.get().GetAllAvailableResources()
        return result

    def get_all_total_resources(self):
        """Return the entries produced by ``GetAllTotalResources``."""
        cdef c_vector[c_string] result
        with nogil:
            result = self.inner.get().GetAllTotalResources()
        return result

    def get_task_events(self):
        """Return the entries produced by ``GetAllTaskEvents``."""
        cdef c_vector[c_string] result
        with nogil:
            result = self.inner.get().GetAllTaskEvents()
        return result

    def get_all_resource_usage(self):
        """Get newest resource usage of all nodes from GCS service.

        Returns:
            The usage as ``bytes``, or None if the C++ call returned no value.
        """
        cdef unique_ptr[c_string] result
        with nogil:
            result = self.inner.get().GetAllResourceUsage()
        if result:
            return c_string(result.get().data(), result.get().size())
        return None

    def get_actor_table(self, job_id, actor_state_name):
        """Return actor info entries, optionally filtered by job and state.

        Args:
            job_id: None, or an object with a ``binary()`` method (e.g. a
                ``JobID``) restricting the result to that job.
            actor_state_name: None, or a state name (converted to a C++
                string) restricting the result to actors in that state.

        Returns:
            list[bytes]: The entries returned by ``GetAllActorInfo``. No actor
            ID filter is ever applied.
        """
        cdef c_vector[c_string] result
        # Always empty: this wrapper never filters by actor ID.
        cdef optional[CActorID] cactor_id = nullopt
        cdef optional[CJobID] cjob_id
        cdef optional[c_string] cactor_state_name
        cdef c_string c_name
        if job_id is not None:
            cjob_id = make_optional[CJobID](CJobID.FromBinary(job_id.binary()))
        if actor_state_name is not None:
            c_name = actor_state_name
            cactor_state_name = make_optional[c_string](c_name)
        with nogil:
            result = self.inner.get().GetAllActorInfo(
                cactor_id, cjob_id, cactor_state_name)
        return result

    def get_actor_info(self, actor_id):
        """Return the info for ``actor_id`` as ``bytes``, or None if absent.

        Args:
            actor_id: An object with a ``binary()`` method (e.g. an ``ActorID``).
        """
        cdef unique_ptr[c_string] actor_info
        cdef CActorID cactor_id = CActorID.FromBinary(actor_id.binary())
        with nogil:
            actor_info = self.inner.get().GetActorInfo(cactor_id)
        if actor_info:
            return c_string(actor_info.get().data(), actor_info.get().size())
        return None

    def get_worker_table(self):
        """Return the entries produced by ``GetAllWorkerInfo``."""
        cdef c_vector[c_string] result
        with nogil:
            result = self.inner.get().GetAllWorkerInfo()
        return result

    def get_worker_info(self, worker_id):
        """Return the info for ``worker_id`` as ``bytes``, or None if absent.

        Args:
            worker_id: An object with a ``binary()`` method.
        """
        cdef unique_ptr[c_string] worker_info
        cdef CWorkerID cworker_id = <CWorkerID>CUniqueID.FromBinary(worker_id.binary())
        with nogil:
            worker_info = self.inner.get().GetWorkerInfo(cworker_id)
        if worker_info:
            return c_string(worker_info.get().data(), worker_info.get().size())
        return None

    def add_worker_info(self, serialized_string):
        """Pass ``serialized_string`` to ``AddWorkerInfo``.

        Returns:
            bool: The result of the C++ call.
        """
        cdef c_bool result
        cdef c_string cserialized_string = serialized_string
        with nogil:
            result = self.inner.get().AddWorkerInfo(cserialized_string)
        return result

    def get_worker_debugger_port(self, worker_id):
        """Return the debugger port recorded for ``worker_id`` as an ``int``."""
        cdef c_uint32_t result
        cdef CWorkerID cworker_id = <CWorkerID>CUniqueID.FromBinary(worker_id.binary())
        with nogil:
            result = self.inner.get().GetWorkerDebuggerPort(cworker_id)
        return result

    def update_worker_debugger_port(self, worker_id, debugger_port):
        """Set the debugger port for ``worker_id``.

        Args:
            worker_id: An object with a ``binary()`` method.
            debugger_port: Port number, converted to ``uint32_t``.

        Returns:
            bool: The result of the C++ call.
        """
        cdef c_bool result
        cdef CWorkerID cworker_id = <CWorkerID>CUniqueID.FromBinary(worker_id.binary())
        cdef c_uint32_t cdebugger_port = debugger_port
        with nogil:
            result = self.inner.get().UpdateWorkerDebuggerPort(
                cworker_id,
                cdebugger_port)
        return result

    def update_worker_num_paused_threads(self, worker_id, num_paused_threads_delta):
        """Adjust the paused-thread count of ``worker_id`` by a delta.

        Args:
            worker_id: An object with a ``binary()`` method.
            num_paused_threads_delta: Signed change, converted to ``int32_t``.

        Returns:
            bool: The result of the C++ call.
        """
        cdef c_bool result
        cdef CWorkerID cworker_id = <CWorkerID>CUniqueID.FromBinary(worker_id.binary())
        cdef c_int32_t cnum_paused_threads_delta = num_paused_threads_delta

        with nogil:
            result = self.inner.get().UpdateWorkerNumPausedThreads(
                cworker_id, cnum_paused_threads_delta)
        return result

    def get_placement_group_table(self):
        """Return the entries produced by ``GetAllPlacementGroupInfo``."""
        cdef c_vector[c_string] result
        with nogil:
            result = self.inner.get().GetAllPlacementGroupInfo()
        return result

    def get_placement_group_info(self, placement_group_id):
        """Return the info for ``placement_group_id`` as ``bytes``, or None.

        Args:
            placement_group_id: An object with a ``binary()`` method.
        """
        cdef unique_ptr[c_string] result
        cdef CPlacementGroupID cplacement_group_id = (
            CPlacementGroupID.FromBinary(placement_group_id.binary()))
        with nogil:
            result = self.inner.get().GetPlacementGroupInfo(
                cplacement_group_id)
        if result:
            return c_string(result.get().data(), result.get().size())
        return None

    def get_placement_group_by_name(self, placement_group_name, ray_namespace):
        """Look up a placement group by name within ``ray_namespace``.

        Returns:
            The info as ``bytes``, or None if the C++ call returned no value.
        """
        cdef unique_ptr[c_string] result
        cdef c_string cplacement_group_name = placement_group_name
        cdef c_string cray_namespace = ray_namespace
        with nogil:
            result = self.inner.get().GetPlacementGroupByName(
                cplacement_group_name, cray_namespace)
        if result:
            return c_string(result.get().data(), result.get().size())
        return None

    def get_system_config(self):
        """Return the result of ``GetSystemConfig()`` as ``bytes``."""
        cdef c_string result
        # Release the GIL for the C++ call, like every other method here, so
        # a slow GCS query does not stall other Python threads.
        with nogil:
            result = self.inner.get().GetSystemConfig()
        return result

    def get_node(self, node_id):
        """Return a dict describing the node identified by ``node_id``.

        Args:
            node_id: Node identifier, converted to a C++ string.

        Returns:
            dict: Socket names, ports, the hex ``node_id`` and decoded labels.

        Raises:
            RuntimeError: If ``GetNode`` returns a non-OK status. The message
                is the status message decoded to ``str``.
        """
        cdef CRayStatus status
        cdef c_string cnode_id = node_id
        cdef c_string cnode_info_str
        cdef CGcsNodeInfo c_node_info
        with nogil:
            status = self.inner.get().GetNode(cnode_id, &cnode_info_str)
        if not status.ok():
            # Decode so the error reads as text rather than a b'...' literal.
            raise RuntimeError(status.message().decode())
        c_node_info.ParseFromString(cnode_info_str)
        c_labels = PythonGetNodeLabels(c_node_info)
        return {
            "object_store_socket_name": c_node_info.object_store_socket_name().decode(),
            "raylet_socket_name": c_node_info.raylet_socket_name().decode(),
            "node_manager_port": c_node_info.node_manager_port(),
            "node_id": c_node_info.node_id().hex(),
            "runtime_env_agent_port": c_node_info.runtime_env_agent_port(),
            "metrics_agent_port": c_node_info.metrics_agent_port(),
            "metrics_export_port": c_node_info.metrics_export_port(),
            "dashboard_agent_listen_port": c_node_info.dashboard_agent_listen_port(),
            "labels": {key.decode(): value.decode() for key, value in c_labels},
        }
