import sys
from collections import defaultdict
from typing import Dict

import pytest

import ray
from ray._common.test_utils import wait_for_condition
from ray._private.test_utils import (
    PrometheusTimeseries,
    raw_metric_timeseries,
)
from ray._private.worker import RayContext
from ray.util.placement_group import remove_placement_group

_SYSTEM_CONFIG = {
    "metrics_report_interval_ms": 200,
}


def groups_by_state(info: RayContext, timeseries: PrometheusTimeseries) -> Dict:
    """Return the number of placement groups in each state.

    Scrapes metrics via ``raw_metric_timeseries`` and sums every
    ``ray_placement_groups`` sample by its ``State`` label. States whose total
    is zero are dropped, so the result can be compared with ``==`` against an
    expected dict that lists only the non-empty states.

    Args:
        info: Context returned by ``ray.init``. It is forwarded unchanged to
            ``raw_metric_timeseries``.
        timeseries: Passed through to ``raw_metric_timeseries``. Presumably it
            accumulates scraped samples across calls (the tests create a fresh
            one per phase); confirm against its definition.

    Returns:
        A plain dict mapping each state name to its summed sample value. Only
        states with a non-zero total are included.
    """
    res = raw_metric_timeseries(info, timeseries)

    # Tally into a separately named dict so the ``info`` argument (the Ray
    # context) is not shadowed by a value of a different type.
    totals = defaultdict(int)
    if "ray_placement_groups" in res:
        for sample in res["ray_placement_groups"]:
            totals[sample.labels["State"]] += sample.value

    # Keep only states that currently have groups. Building a new dict avoids
    # the copy-then-delete dance and yields a readable repr for the log line.
    groups = {state: total for state, total in totals.items() if total != 0}
    print(f"Groups by state: {groups}")
    return groups


def test_basic_states(shutdown_only):
    """Check that placement-group state metrics follow creation and removal."""
    info = ray.init(num_cpus=3, _system_config=_SYSTEM_CONFIG)

    def wait_for_states(expected_states):
        # A fresh timeseries per phase, presumably so samples scraped during
        # an earlier phase do not leak into this comparison.
        series = PrometheusTimeseries()
        wait_for_condition(
            lambda: groups_by_state(info, series) == expected_states,
            timeout=20,
            retry_interval_ms=500,
        )

    # Two 1-CPU groups fit on the 3-CPU node. The 4-CPU group cannot, so it is
    # expected to stay PENDING.
    feasible = [ray.util.placement_group(bundles=[{"CPU": 1}]) for _ in range(2)]
    infeasible = ray.util.placement_group(bundles=[{"CPU": 4}])
    ray.get([pg.ready() for pg in feasible])

    wait_for_states({"CREATED": 2, "PENDING": 1})

    for pg in feasible + [infeasible]:
        remove_placement_group(pg)

    wait_for_states({"REMOVED": 3})


if __name__ == "__main__":
    # Allow running this module directly. Propagate pytest's exit status.
    exit_code = pytest.main(["-sv", __file__])
    sys.exit(exit_code)
