#!/usr/bin/env python3
"""
This script compares benchmark results from two release directories one by one.

Usage:
python3 release/release_logs/compare_perf_metrics <old-dir> <new-dir>
"""

import json
import pathlib
import argparse
import sys


def parse_args():
    """Parse the command line and return the resulting namespace.

    Two positional string arguments are accepted, in this order:
    ``old_dir_name`` (previous release logs) and ``new_dir_name``
    (new release logs).
    """
    cli = argparse.ArgumentParser(
        description="Automate the process of calculating relative change in "
        "perf_metrics. This makes catching regressions much easier."
    )

    # (argument name, help text) pairs, registered in positional order.
    positionals = (
        (
            "old_dir_name",
            "The name of the directory containing the last release "
            "performance logs, e.g. 2.2.0",
        ),
        (
            "new_dir_name",
            "The name of the directory containing the new release "
            "performance logs, e.g. 2.3.0",
        ),
    )
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, type=str, help=help_text)

    return cli.parse_args()


def main(old_dir_name, new_dir_name):
    """Compare every result file shared by two release directories.

    Files are paired by bare filename (see ``group_by_filename``). A file
    present in only one directory is reported and skipped. For each pair the
    per-metric comparison is delegated to ``get_regressions``; the collected
    results are then printed to stdout in this order:

    1. metrics missing from the new file,
    2. metrics missing from the old file,
    3. throughput regressions, lowest new/old ratio first,
    4. latency regressions, highest new/old ratio first.

    Args:
        old_dir_name: Directory holding the previous release's result files.
        new_dir_name: Directory holding the new release's result files.
    """
    old_files = list(walk(old_dir_name))
    new_files = list(walk(new_dir_name))

    old_by_name = group_by_filename(old_files, old_dir_name)
    new_by_name = group_by_filename(new_files, new_dir_name)

    all_filenames = set(old_by_name.keys()) | set(new_by_name.keys())

    throughput_regressions = []
    latency_regressions = []
    # Each entry is a (path, metric_name) pair. The path has to be captured
    # inside the loop: by the time the messages are printed, ``old_path`` and
    # ``new_path`` only describe the last filename visited.
    missing_in_new = []
    missing_in_old = []

    for filename in sorted(all_filenames):
        old_path = old_by_name.get(filename)
        new_path = new_by_name.get(filename)

        if old_path is None:
            print(f"{old_dir_name} is missing {filename}")
            continue
        if new_path is None:
            print(f"{new_dir_name} is missing {filename}")
            continue

        # Compare the two files
        (
            throughput,
            latency,
            missing_new_metrics,
            missing_old_metrics,
        ) = get_regressions(old_path, new_path)

        throughput_regressions.extend(throughput)
        latency_regressions.extend(latency)
        missing_in_new.extend((new_path, metric) for metric in missing_new_metrics)
        missing_in_old.extend((old_path, metric) for metric in missing_old_metrics)

    for path, metric in missing_in_new:
        print(f"{path} does not have {metric}")

    for path, metric in missing_in_old:
        print(f"{path} does not have {metric}")

    # Entries are (ratio, message) tuples, so sorting orders by ratio first.
    throughput_regressions.sort()
    for _, regression in throughput_regressions:
        print(regression)

    latency_regressions.sort(reverse=True)
    for _, regression in latency_regressions:
        print(regression)


def walk(dir_name):
    """Yield every non-directory path found at or below ``dir_name``.

    The tree is traversed iteratively with an explicit stack rather than by
    recursion. Anything for which ``is_dir()`` is false is yielded as-is,
    which includes ``dir_name`` itself when it is not a directory.
    """
    pending = [pathlib.Path(dir_name)]
    while pending:
        current = pending.pop()
        if current.is_dir():
            # Queue the directory's children for a later visit.
            pending.extend(current.iterdir())
            continue
        yield current


def group_by_filename(paths, base_dir):
    """
    Build a dict that maps each bare filename to the first path carrying it.

    When a filename shows up again, a warning naming the path relative to
    ``base_dir`` is printed and the earlier entry is kept.
    """
    first_seen = {}
    for candidate in paths:
        filename = candidate.name
        # Evaluated for every path, so a path outside base_dir raises here.
        relative = candidate.relative_to(base_dir)
        if filename in first_seen:
            print(f"Warning: duplicate filename {filename} found at {relative}")
            continue
        first_seen[filename] = candidate
    return first_seen


def get_compare_list(old, new):
    """Partition two collections of keys into shared and one-sided sets.

    Returns a 3-tuple of sets: items present in both ``old`` and ``new``,
    items present only in ``old``, and items present only in ``new``.
    """
    before, after = set(old), set(new)

    shared = before & after
    only_before = before - after
    only_after = after - before

    return shared, only_before, only_after

def _relative_ratio(old_value, new_value):
    """Return ``new_value / old_value`` without failing on a zero baseline.

    With a zero baseline the ratio is 1.0 when the new value is also zero
    (no change), otherwise positive or negative infinity following the sign
    of the new value.
    """
    if old_value != 0:
        return new_value / old_value
    if new_value == 0:
        return 1.0
    return float("inf") if new_value > 0 else float("-inf")


def get_regressions(old_path, new_path):
    """Compare the ``perf_metrics`` of two JSON result files.

    Each file is read as JSON and must contain a ``"perf_metrics"`` list whose
    entries carry ``"perf_metric_name"``, ``"perf_metric_type"`` and
    ``"perf_metric_value"``. Metrics are matched by name. When a name appears
    in both files, the type recorded in the new file is the one used.

    A ``THROUGHPUT`` metric regresses when new/old is below 1.0; a ``LATENCY``
    metric regresses when new/old is above 1.0.

    Args:
        old_path: Path of the baseline result file.
        new_path: Path of the new result file; also named in each message.

    Returns:
        A 4-tuple ``(throughput_regressions, latency_regressions,
        missing_in_new, missing_in_old)``. The first two are lists of
        ``(ratio, message)`` tuples; the last two are sets of metric names
        found only in the old file and only in the new file respectively.

    Raises:
        ValueError: If a metric present in both files has a type other than
            ``THROUGHPUT`` or ``LATENCY``.
    """
    with open(old_path, "r", encoding="utf-8") as f:
        old = json.load(f)

    with open(new_path, "r", encoding="utf-8") as f:
        new = json.load(f)

    def by_name(root, field):
        # Map metric name -> requested field for every entry in the file.
        return {
            metric["perf_metric_name"]: metric[field]
            for metric in root["perf_metrics"]
        }

    old_values = by_name(old, "perf_metric_value")
    new_values = by_name(new, "perf_metric_value")

    # Later entries win, so the new file's type overrides the old file's.
    perf_metric_types = {
        **by_name(old, "perf_metric_type"),
        **by_name(new, "perf_metric_type"),
    }

    to_compare, missing_in_new, missing_in_old = get_compare_list(
        old_values.keys(),
        new_values.keys(),
    )

    throughput_regressions = []
    latency_regression = []
    for perf_metric_name in to_compare:
        perf_type = perf_metric_types[perf_metric_name]
        old_value = old_values[perf_metric_name]
        new_value = new_values[perf_metric_name]

        # A baseline of 0 must not abort the whole report with a
        # ZeroDivisionError, so the division goes through the helper.
        ratio = _relative_ratio(old_value, new_value)
        ratio_str = f"{100 * abs(ratio - 1):.2f}%"
        regression_message = (
            f"REGRESSION {ratio_str}: {perf_metric_name} ({perf_type}) "
            f"regresses from {old_value} to {new_value} ({ratio_str}) "
            f"in {new_path}"
        )

        if perf_type == "THROUGHPUT":
            if ratio < 1.0:
                throughput_regressions.append((ratio, regression_message))
        elif perf_type == "LATENCY":
            if ratio > 1.0:
                latency_regression.append((ratio, regression_message))
        else:
            raise ValueError(
                f"perf_metric_name {perf_metric_name!r} "
                f"not of expected type {perf_type}"
            )

    return throughput_regressions, latency_regression, missing_in_new, missing_in_old


# Script entry point: parse the CLI arguments, run the comparison, and hand
# main()'s return value to sys.exit as the process exit status.
if __name__ == "__main__":
    args = parse_args()
    exit_status = main(args.old_dir_name, args.new_dir_name)
    sys.exit(exit_status)
