# coding: utf-8
import os
import shutil
import sys
import tempfile
import unittest

import numpy as np
import pandas
import pytest
from hyperopt import hp
from nevergrad.optimization import optimizerlib
from packaging.version import Version
from zoopt import ValueType

import ray
from ray import tune
from ray.rllib import _register_all
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.search import ConcurrencyLimiter
from ray.tune.search.ax import AxSearch
from ray.tune.search.bayesopt import BayesOptSearch
from ray.tune.search.bohb import TuneBOHB
from ray.tune.search.hebo import HEBOSearch
from ray.tune.search.hyperopt import HyperOptSearch
from ray.tune.search.nevergrad import NevergradSearch
from ray.tune.search.optuna import OptunaSearch
from ray.tune.search.zoopt import ZOOptSearch


class AbstractWarmStartTest:
    """Shared warm-start / restore checks for Tune search algorithms.

    Meant to be mixed into a ``unittest.TestCase`` subclass that overrides
    :meth:`set_basic_conf`. Each test runs 5 samples, resumes a freshly built
    searcher from saved state for 5 more samples, and asserts that the two
    config lists concatenated equal those of one uninterrupted 10-sample run.
    """

    def setUp(self):
        """Start a single-CPU Ray instance and create a scratch directory."""
        ray.init(num_cpus=1)
        self.tmpdir = tempfile.mkdtemp()
        self.experiment_name = "results"

    def tearDown(self):
        """Remove the scratch directory and shut Ray down."""
        try:
            shutil.rmtree(self.tmpdir)
        finally:
            # Always stop Ray, so a failed directory cleanup does not leave
            # it running when the next test's ``setUp`` calls ``ray.init``.
            ray.shutdown()
            _register_all()

    def set_basic_conf(self):
        """Return ``(search_alg, cost)``; must be overridden by subclasses."""
        raise NotImplementedError()

    def get_scheduler(self):
        """Return the scheduler handed to ``tune.run`` (``None`` by default)."""
        return None

    def treat_trial_config(self, trial_config):
        """Hook to post-process the list of trial configs before comparison."""
        return trial_config

    def _build_limited_searcher(self):
        """Build the subclass's searcher, wrapped in a ``ConcurrencyLimiter``.

        A searcher that already is a ``ConcurrencyLimiter`` is used as-is;
        anything else is limited to one concurrent trial.
        """
        search_alg, cost = self.set_basic_conf()
        if not isinstance(search_alg, ConcurrencyLimiter):
            search_alg = ConcurrencyLimiter(search_alg, 1)
        return search_alg, cost

    def _run(self, cost, search_alg, num_samples, **run_kwargs):
        """Call ``tune.run`` with the options common to every run here."""
        return tune.run(
            cost,
            num_samples=num_samples,
            search_alg=search_alg,
            scheduler=self.get_scheduler(),
            verbose=0,
            reuse_actors=True,
            **run_kwargs,
        )

    def _assert_resumed_matches_full(self, first, resumed, full):
        """Assert configs of ``first`` + ``resumed`` equal those of ``full``."""
        first_cfgs, resumed_cfgs, full_cfgs = (
            self.treat_trial_config([trial.config for trial in results.trials])
            for results in (first, resumed, full)
        )
        self.assertEqual(first_cfgs + resumed_cfgs, full_cfgs)

    def run_part_from_scratch(self):
        """Run the first 5 samples and save the searcher state.

        Returns:
            A tuple ``(results, numpy_random_state, checkpoint_path)`` where
            ``checkpoint_path`` is the file written by ``search_alg.save``.
        """
        np.random.seed(162)
        search_alg, cost = self._build_limited_searcher()
        results_exp_1 = self._run(
            cost,
            search_alg,
            num_samples=5,
            name=self.experiment_name,
            storage_path=self.tmpdir,
        )
        checkpoint_path = os.path.join(self.tmpdir, "warmStartTest.pkl")
        search_alg.save(checkpoint_path)
        return results_exp_1, np.random.get_state(), checkpoint_path

    def run_from_experiment_restore(self, random_state):
        """Run 5 more samples with a searcher restored from the experiment dir.

        Args:
            random_state: NumPy global random state to reinstate before the
                searcher is rebuilt (as returned by ``np.random.get_state``).
                ``None`` leaves the global state untouched.
        """
        # Mirror run_explicit_restore: previously this argument was ignored.
        if random_state is not None:
            np.random.set_state(random_state)
        search_alg, cost = self._build_limited_searcher()
        search_alg.restore_from_dir(os.path.join(self.tmpdir, self.experiment_name))
        return self._run(
            cost,
            search_alg,
            num_samples=5,
            name=self.experiment_name,
            storage_path=self.tmpdir,
        )

    def run_explicit_restore(self, random_state, checkpoint_path):
        """Run 5 more samples with a searcher restored from ``checkpoint_path``.

        Args:
            random_state: NumPy global random state to reinstate first.
            checkpoint_path: File previously written by ``search_alg.save``.
        """
        np.random.set_state(random_state)
        search_alg2, cost = self._build_limited_searcher()
        search_alg2.restore(checkpoint_path)
        return self._run(cost, search_alg2, num_samples=5)

    def run_full(self):
        """Run all 10 samples in one go, from the same NumPy seed."""
        np.random.seed(162)
        search_alg3, cost = self._build_limited_searcher()
        return self._run(cost, search_alg3, num_samples=10)

    def testWarmStart(self):
        """Resuming from an explicit searcher checkpoint matches a full run."""
        results_exp_1, r_state, checkpoint_path = self.run_part_from_scratch()
        results_exp_2 = self.run_explicit_restore(r_state, checkpoint_path)
        results_exp_3 = self.run_full()
        self._assert_resumed_matches_full(results_exp_1, results_exp_2, results_exp_3)

    def testRestore(self):
        """Resuming from the experiment directory matches a full run."""
        results_exp_1, r_state, _ = self.run_part_from_scratch()
        results_exp_2 = self.run_from_experiment_restore(r_state)
        results_exp_3 = self.run_full()
        self._assert_resumed_matches_full(results_exp_1, results_exp_2, results_exp_3)


class HyperoptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
    """Warm-start checks for ``HyperOptSearch``."""

    def set_basic_conf(self):
        """Return a seeded HyperOpt searcher and a sum-of-squares objective."""

        def cost(config):
            x, y, z = config["x"], config["y"], config["z"]
            tune.report(dict(loss=x ** 2 + y ** 2 + z ** 2))

        # (low, high) bounds of the uniform prior for each parameter.
        bounds = {"x": (0, 10), "y": (-10, 10), "z": (-10, 0)}
        search_space = {
            label: hp.uniform(label, low, high)
            for label, (low, high) in bounds.items()
        }

        searcher = HyperOptSearch(
            search_space,
            metric="loss",
            mode="min",
            random_state_seed=5,
            n_initial_points=1,
        )
        return searcher, cost


class BayesoptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
    """Warm-start checks for ``BayesOptSearch``."""

    def set_basic_conf(self, analysis=None):
        """Return a BayesOpt searcher and its objective.

        Args:
            analysis: Forwarded to ``BayesOptSearch`` as ``analysis``.
        """

        def cost(config):
            height_term = (config["height"] - 14) ** 2
            width_term = abs(config["width"] - 3)
            tune.report(dict(loss=height_term - width_term))

        bounds = {"width": (0, 20), "height": (-100, 100)}
        searcher = BayesOptSearch(
            bounds, metric="loss", mode="min", analysis=analysis
        )
        return searcher, cost

    def testBootStrapAnalysis(self):
        """Run a second experiment whose searcher is given the first's results."""
        previous_run = self.run_full()
        searcher, cost = self.set_basic_conf(previous_run)
        if isinstance(searcher, ConcurrencyLimiter):
            limited = searcher
        else:
            limited = ConcurrencyLimiter(searcher, 1)
        tune.run(
            cost,
            num_samples=10,
            search_alg=limited,
            verbose=0,
            reuse_actors=True,
        )


class NevergradWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
    """Warm-start checks for ``NevergradSearch``."""

    def set_basic_conf(self):
        """Return a OnePlusOne-backed Nevergrad searcher and its objective."""

        def cost(config):
            height_term = (config["height"] - 14) ** 2
            width_term = abs(config["width"] - 3)
            tune.report(dict(loss=height_term - width_term))

        # NOTE(review): presumably one optimizer dimension per named
        # parameter below -- confirm against the Nevergrad docs.
        dimension = 2
        names = ["height", "width"]

        searcher = NevergradSearch(
            optimizerlib.OnePlusOne(dimension),
            space=names,
            metric="loss",
            mode="min",
        )
        return searcher, cost


class OptunaWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
    """Warm-start checks for ``OptunaSearch``."""

    def set_basic_conf(self):
        """Return an Optuna searcher with a seeded TPE sampler and its objective."""
        from optuna.samplers import TPESampler

        def cost(config):
            height_term = (config["height"] - 14) ** 2
            width_term = abs(config["width"] - 3)
            tune.report(dict(loss=height_term - width_term))

        tune_space = {
            "width": tune.uniform(0, 20),
            "height": tune.uniform(-100, 100),
        }
        optuna_space = OptunaSearch.convert_search_space(tune_space)

        seeded_sampler = TPESampler(seed=10)
        searcher = OptunaSearch(
            optuna_space, sampler=seeded_sampler, metric="loss", mode="min"
        )
        return searcher, cost


class ZOOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
    """Warm-start checks for ``ZOOptSearch``."""

    def set_basic_conf(self):
        """Return a ZOOpt searcher (budget 200) and its objective."""

        def cost(param):
            height_term = (param["height"] - 14) ** 2
            width_term = abs(param["width"] - 3)
            tune.report(dict(loss=height_term - width_term))

        height_dim = (ValueType.CONTINUOUS, [-100, 100], 1e-2)
        width_dim = (ValueType.DISCRETE, [0, 20], False)

        searcher = ZOOptSearch(
            # ASRacos is the only algorithm supported at the moment.
            algo="Asracos",
            budget=200,
            dim_dict={"height": height_dim, "width": width_dim},
            metric="loss",
            mode="min",
        )
        return searcher, cost


@pytest.mark.skipif(sys.version_info >= (3, 12), reason="HEBO doesn't support py312")
class HEBOWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
    """Warm-start checks for ``HEBOSearch``."""

    def set_basic_conf(self):
        """Return a concurrency-limited, seeded HEBO searcher and its objective.

        Skips the running test when pandas 2.0.0 or newer is installed.
        """
        pandas_too_new = Version(pandas.__version__) >= Version("2.0.0")
        if pandas_too_new:
            pytest.skip("HEBO does not support pandas>=2.0.0")

        # Imported lazily, only once the pandas check above has passed.
        from hebo.design_space.design_space import DesignSpace

        def cost(param):
            height_term = (param["height"] - 14) ** 2
            width_term = abs(param["width"] - 3)
            tune.report(dict(loss=height_term - width_term))

        design_space = DesignSpace().parse(
            [
                {"name": "width", "type": "num", "lb": 0, "ub": 20},
                {"name": "height", "type": "num", "lb": -100, "ub": 100},
            ]
        )
        hebo = HEBOSearch(
            space=design_space, metric="loss", mode="min", random_state_seed=5
        )
        # Deliberately allow 10 concurrent trials: HEBO caches suggestions,
        # so this makes the test faster.
        return ConcurrencyLimiter(hebo, max_concurrent=10), cost


class AxWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
    """Warm-start checks for ``AxSearch``."""

    def set_basic_conf(self):
        """Return an Ax searcher using a seeded Sobol-only strategy, plus its objective."""
        from ax.modelbridge.generation_strategy import (
            GenerationStep,
            GenerationStrategy,
        )
        from ax.modelbridge.registry import Models
        from ax.service.ax_client import AxClient

        def cost(config):
            height_term = (config["height"] - 14) ** 2
            width_term = abs(config["width"] - 3)
            tune.report(dict(loss=height_term - width_term))

        def sobol_only(**trial_budget):
            # One seeded Sobol step; the budget keyword differs by Ax version.
            step = GenerationStep(
                model=Models.SOBOL,
                model_kwargs={"seed": 4321},
                **trial_budget,
            )
            return GenerationStrategy(steps=[step])

        ax_space = AxSearch.convert_search_space(
            {"width": tune.uniform(0, 20), "height": tune.uniform(-100, 100)}
        )

        # Use a Sobol generation strategy to keep the runs reproducible.
        try:
            # ax-platform>=0.2.0
            strategy = sobol_only(num_trials=-1)
        except TypeError:
            # ax-platform<0.2.0
            strategy = sobol_only(num_arms=-1)

        ax_client = AxClient(random_seed=4321, generation_strategy=strategy)
        ax_client.create_experiment(
            parameters=ax_space, objective_name="loss", minimize=True
        )

        return AxSearch(ax_client=ax_client), cost


@pytest.mark.skipif(sys.version_info >= (3, 12), reason="BOHB doesn't support py312")
class BOHBWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
    """Warm-start checks for ``TuneBOHB`` paired with ``HyperBandForBOHB``."""

    def get_scheduler(self):
        """Return the HyperBand-for-BOHB scheduler used by every run."""
        return HyperBandForBOHB(max_t=100, metric="loss", mode="min")

    def set_basic_conf(self):
        """Return a seeded BOHB searcher and an objective that reports 10 times."""

        def cost(config):
            height_term = (config["height"] - 14) ** 2
            for step in range(10):
                width_term = abs(config["width"] - 3 - step)
                tune.report(dict(loss=height_term - width_term))

        search_space = {
            "width": tune.uniform(0, 20),
            "height": tune.uniform(-100, 100),
        }
        searcher = TuneBOHB(space=search_space, metric="loss", mode="min", seed=1)
        return searcher, cost


if __name__ == "__main__":
    # Run this file under pytest, forwarding any extra command-line arguments.
    pytest_args = ["-v", __file__, *sys.argv[1:]]
    sys.exit(pytest.main(pytest_args))
