"""Hyperparameter optimizer using optuna."""
import pickle
import optuna
from mala.common.parameters import printout
from mala.network.hyper_opt import HyperOpt
from mala.network.objective_base import ObjectiveBase
from mala.network.naswot_pruner import NASWOTPruner
from mala.network.multi_training_pruner import MultiTrainingPruner
from mala.common.parallelizer import parallel_warn
class HyperOptOptuna(HyperOpt):
    """Hyperparameter optimizer using Optuna.
    Parameters
    ----------
    params : mala.common.parameters.Parameters
        Parameters used to create this hyperparameter optimizer.
    data : mala.datahandling.data_handler.DataHandler
        DataHandler holding the data for the hyperparameter optimization.
    use_pkl_checkpoints : bool
        If true, .pkl checkpoints will be created.
    Attributes
    ----------
    params : mala.common.parameters.Parameters
        MALA Parameters object.
    objective : mala.network.objective_base.ObjectiveBase
        MALA objective to be optimized, i.e., a MALA NN model training.
    study : optuna.study.Study
        An Optuna study used to collect the results of the hyperparameter
        optimization.
    """
    def __init__(self, params, data, use_pkl_checkpoints=False):
        super(HyperOptOptuna, self).__init__(
            params, data, use_pkl_checkpoints=use_pkl_checkpoints
        )
        self.params = params
        # Make the sample behave in a reproducible way, if so specified by
        # the user.
        sampler = optuna.samplers.TPESampler(
            seed=params.manual_seed,
            multivariate=params.hyperparameters.use_multivariate,
        )
        # See if the user specified a pruner.
        pruner = None
        if self.params.hyperparameters.pruner is not None:
            if self.params.hyperparameters.pruner == "naswot":
                pruner = NASWOTPruner(self.params, data)
            elif self.params.hyperparameters.pruner == "multi_training":
                if self.params.hyperparameters.number_training_per_trial > 1:
                    pruner = MultiTrainingPruner(self.params)
                else:
                    printout(
                        "MultiTrainingPruner requested, but only one "
                        "training per trial specified; skipping pruner "
                        "creation."
                    )
            else:
                raise Exception("Invalid pruner type selected.")
        # Create the study.
        if self.params.hyperparameters.rdb_storage is None:
            self.study = optuna.create_study(
                direction=self.params.hyperparameters.direction,
                sampler=sampler,
                study_name=self.params.hyperparameters.study_name,
                pruner=pruner,
            )
        else:
            if self.params.hyperparameters.study_name is None:
                raise Exception(
                    "If RDB storage is used, a name for the study "
                    "has to be provided."
                )
            if "sqlite" in self.params.hyperparameters.rdb_storage:
                engine_kwargs = {
                    "connect_args": {
                        "timeout": self.params.hyperparameters.sqlite_timeout
                    }
                }
            else:
                engine_kwargs = None
            rdb_storage = optuna.storages.RDBStorage(
                url=self.params.hyperparameters.rdb_storage,
                heartbeat_interval=self.params.hyperparameters.rdb_storage_heartbeat,
                engine_kwargs=engine_kwargs,
            )
            self.study = optuna.create_study(
                direction=self.params.hyperparameters.direction,
                sampler=sampler,
                study_name=self.params.hyperparameters.study_name,
                storage=rdb_storage,
                load_if_exists=True,
                pruner=pruner,
            )
        self._checkpoint_counter = 0
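
    # Hedged usage sketch (illustrative only; parameter values, file names and
    # the DataHandler setup are assumptions, not from the original source):
    #
    #   import mala
    #
    #   params = mala.Parameters()
    #   params.hyperparameters.n_trials = 20
    #   # Optional: persist trials in an SQLite RDB storage; a study_name is
    #   # then required (see the constructor above).
    #   params.hyperparameters.rdb_storage = "sqlite:///hyperopt.db"
    #   params.hyperparameters.study_name = "hyperopt"
    #   datahandler = mala.DataHandler(params)
    #   # ... add and prepare data ...
    #   hyperopt = HyperOptOptuna(params, datahandler)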
    def set_optimal_parameters(self):
        """
        Set the optimal parameters found in the present study.
        The parameters will be written to the parameter object with which the
        hyperparameter optimizer was created.
        """
        # Parse the parameters from the best trial.
        self.objective.parse_trial_optuna(self.study.best_trial) 
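
    # Hedged sketch of the typical flow around set_optimal_parameters()
    # (perform_study() is defined outside this section):
    #
    #   hyperopt.perform_study()
    #   hyperopt.set_optimal_parameters()
    #   # params now holds the best hyperparameters found in the study and
    #   # can be used to set up the actual network training.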
    def get_trials_from_study(self):
        """
        Return the trials from the last study.
        Only returns completed trials.
        Returns
        -------
        last_trials: list
            A list of optuna.FrozenTrial objects.
        """
        return self.study.get_trials(
            states=(optuna.trial.TrialState.COMPLETE,)
        ) 
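
    # Hedged sketch: inspecting the completed trials of a study (the
    # attributes on the FrozenTrial objects follow the Optuna API):
    #
    #   for trial in hyperopt.get_trials_from_study():
    #       print(trial.number, trial.value, trial.params)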
    @staticmethod
    def requeue_zombie_trials(study_name, rdb_storage):
        """
        Put zombie trials back into the queue to be investigated.
        When using Optuna with scheduling systems in HPC infrastructure,
        zombie trials can occur. These are trials that are still marked
        as "RUNNING", but are, in actuality, dead, since the HPC job ended.
        This function takes a saved hyperparameter study, and puts all
        "RUNNING" trials als "WAITING". Upon the next execution from
        checkpoint, they will be executed.
        BE CAREFUL! DO NOT USE APPLY THIS TO A RUNNING STUDY, IT WILL MESS THE
        STUDY UP! ONLY USE THIS ONCE ALL JOBS HAVE FINISHED, TO CLEAN UP,
        AND THEN RESUBMIT!
        Parameters
        ----------
        rdb_storage : string
            Adress of the RDB storage to be cleaned.
        study_name : string
            Name of the study in the storage. Same as the checkpoint name.
        """
        study_to_clean = optuna.load_study(
            study_name=study_name, storage=rdb_storage
        )
        parallel_warn(
            "WARNING: Your about to clean/requeue a study."
            " This operation should not be done to an already"
            " running study."
        )
        trials = study_to_clean.get_trials()
        cleaned_trials = []
        for trial in trials:
            if trial.state == optuna.trial.TrialState.RUNNING:
                kwds = dict(
                    trial_id=trial._trial_id,
                    state=optuna.trial.TrialState.WAITING,
                )
                if hasattr(study_to_clean._storage, "set_trial_state"):
                    # Optuna 2.x
                    study_to_clean._storage.set_trial_state(**kwds)
                else:
                    # Optuna 3.x
                    study_to_clean._storage.set_trial_state_values(
                        values=None, **kwds
                    )
                cleaned_trials.append(trial.number)
        printout("Cleaned trials: ", cleaned_trials, min_verbosity=0) 
    @classmethod
    def resume_checkpoint(
        cls,
        checkpoint_name,
        alternative_storage_path=None,
        no_data=False,
        use_pkl_checkpoints=False,
    ):
        """
        Prepare resumption of hyperparameter optimization from a checkpoint.
        Please note that to actually resume the optimization,
        HyperOptOptuna.perform_study() still has to be called.
        Parameters
        ----------
        checkpoint_name : string
            Name of the checkpoint from which the checkpoint is loaded.
        alternative_storage_path: string
            Alternative storage string to load the study from.
            For applications on an HPC cluster it might be necessary to
            slightly modify the storage path between runs, since the SQL
            server might be running on different nodes each time.
        no_data : bool
            If True, the data won't actually be loaded into RAM or scaled.
            This can be useful for cases where a checkpoint is loaded
            for analysis purposes.
        use_pkl_checkpoints : bool
            If true, .pkl checkpoints will be loaded.
        Returns
        -------
        loaded_params : mala.common.parameters.Parameters
            The Parameters saved in the checkpoint.
        new_datahandler : mala.datahandling.data_handler.DataHandler
            The data handler reconstructed from the checkpoint.
        new_hyperopt : HyperOptOptuna
            The hyperparameter optimizer reconstructed from the checkpoint.
        """
        loaded_params, new_datahandler, optimizer_name = (
            cls._resume_checkpoint(
                checkpoint_name,
                no_data=no_data,
                use_pkl_checkpoints=use_pkl_checkpoints,
            )
        )
        if alternative_storage_path is not None:
            loaded_params.hyperparameters.rdb_storage = (
                alternative_storage_path
            )
        new_hyperopt = HyperOptOptuna.load_from_file(
            loaded_params, optimizer_name, new_datahandler
        )
        return loaded_params, new_datahandler, new_hyperopt 
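
    # Hedged usage sketch (the checkpoint name is illustrative, not from the
    # original source): resuming a hyperparameter optimization from a
    # checkpoint.
    #
    #   loaded_params, datahandler, hyperopt = (
    #       HyperOptOptuna.resume_checkpoint("my_hyperopt_checkpoint")
    #   )
    #   hyperopt.perform_study()
    #   hyperopt.set_optimal_parameters()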
    @classmethod
    def load_from_file(cls, params, file_path, data):
        """
        Load a hyperparameter optimizer from a file.
        Parameters
        ----------
        params : mala.common.parameters.Parameters
            Parameters object with which the hyperparameter optimizer
            should be created Has to be compatible with data.
        file_path : string
            Path to the file from which the hyperparameter optimizer should
            be loaded.
        data : mala.datahandling.data_handler.DataHandler
            DataHandler holding the training data.
        Returns
        -------
        loaded_trainer : Network
            The hyperparameter optimizer that was loaded from the file.
        """
        # First, load the checkpoint.
        if params.hyperparameters.rdb_storage is None:
            with open(file_path, "rb") as handle:
                loaded_study = pickle.load(handle)
            # Now, create the HyperOptOptuna object with it.
            loaded_hyperopt = HyperOptOptuna(params, data)
            loaded_hyperopt.study = loaded_study
        else:
            loaded_hyperopt = HyperOptOptuna(params, data)
        return loaded_hyperopt 
    def __get_number_of_completed_trials(self, study):
        """
        Get the number of completed trials from a study.
        Parameters
        ----------
        study : optuna.study.Study
            Study from which the number of completed trials should be
            extracted.
        Returns
        -------
        number_of_completed_trials : int
            Number of completed trials.
        """
        # How to calculate this depends on whether or not a heartbeat was
        # used. If one was used, then both COMPLETE and RUNNING trials
        # can be taken into account, as it can be expected that RUNNING
        # trials will actually finish. If no heartbeat is used,
        # then RUNNING trials might be zombie trials.
        # See
        # https://github.com/optuna/optuna/issues/1883#issuecomment-841844834
        # https://github.com/optuna/optuna/issues/1883#issuecomment-842106950
        if self.params.hyperparameters.rdb_storage_heartbeat is None:
            return len(
                [
                    t
                    for t in study.trials
                    if t.state == optuna.trial.TrialState.COMPLETE
                ]
            )
        else:
            return len(
                [
                    t
                    for t in study.trials
                    if t.state == optuna.trial.TrialState.COMPLETE
                    or t.state == optuna.trial.TrialState.RUNNING
                ]
            )
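
    # Hedged configuration sketch: enabling the RDB heartbeat referenced in
    # the comment above (the interval value is illustrative):
    #
    #   params.hyperparameters.rdb_storage_heartbeat = 60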
    def __check_stopping(self, study, trial):
        """
        Check if this trial was already the maximum number of trials.
        If so, stop the study.
        Parameters
        ----------
        study : optuna.study.Study
            Study in which the trial is running.
        trial : optuna.trial.FrozenTrial
            Trial for which the stopping condition should be tested.
        """
        # How to check for this depends on whether or not a heartbeat was
        # used. If one was used, then both COMPLETE and RUNNING trials
        # can be taken into account, as it can be expected that RUNNING
        # trials will actually finish. If no heartbeat is used,
        # then RUNNING trials might be zombie trials.
        # See
        # https://github.com/optuna/optuna/issues/1883#issuecomment-841844834
        # https://github.com/optuna/optuna/issues/1883#issuecomment-842106950
        completed_trials = self.__get_number_of_completed_trials(study)
        if completed_trials >= self.params.hyperparameters.n_trials:
            self.study.stop()
        # Only check if there are trials to be checked.
        if completed_trials > 0:
            if (
                self.params.hyperparameters.number_bad_trials_before_stopping
                is not None
                and self.params.hyperparameters.number_bad_trials_before_stopping
                > 0
            ):
                if (
                    trial.number - self.study.best_trial.number
                    >= self.params.hyperparameters.number_bad_trials_before_stopping
                ):
                    printout(
                        "No new best trial found in",
                        self.params.hyperparameters.number_bad_trials_before_stopping,
                        "attempts, stopping the study.",
                    )
                    self.study.stop()
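
    # Hedged configuration sketch for the early-stopping criterion above
    # (attribute name as used in __check_stopping; the value is illustrative):
    #
    #   params.hyperparameters.number_bad_trials_before_stopping = 10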
    def __create_checkpointing(self, study, trial):
        """
        Create a checkpoint of optuna study, if necessary.
        This is done based on an internal checkpoint counter.
        Parameters
        ----------
        study : optuna.study.Study
            Study in which the trial is running.
        trial : optuna.trial.FrozenTrial
            Trial for which the checkpoint may be created.
        """
        self._checkpoint_counter += 1
        need_to_checkpoint = False
        if (
            self._checkpoint_counter
            >= self.params.hyperparameters.checkpoints_each_trial
            and self.params.hyperparameters.checkpoints_each_trial > 0
        ):
            need_to_checkpoint = True
            printout(
                str(self.params.hyperparameters.checkpoints_each_trial)
                + " trials have passed, creating a "
                "checkpoint for hyperparameter "
                "optimization.",
                min_verbosity=0,
            )
        if (
            self.params.hyperparameters.checkpoints_each_trial < 0
            and self.__get_number_of_completed_trials(study) > 0
        ):
            if trial.number == study.best_trial.number:
                need_to_checkpoint = True
                printout(
                    "Best trial is " + str(trial.number) + ", creating a "
                    "checkpoint for it.",
                    min_verbosity=0,
                )
        if need_to_checkpoint is True:
            # We need to create a checkpoint!
            self._checkpoint_counter = 0
            self._save_params_and_scaler()
            # The study only has to be saved if no RDB storage is used.
            if self.params.hyperparameters.rdb_storage is None:
                hyperopt_name = (
                    self.params.hyperparameters.checkpoint_name
                    + "_hyperopt.pth"
                )
                with open(hyperopt_name, "wb") as handle:
                    pickle.dump(self.study, handle, protocol=4)
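
    # Hedged sketch (perform_study() is not shown in this section; this is an
    # assumption about how the pieces above fit together): the stopping and
    # checkpointing helpers are written as Optuna callbacks, i.e. callables
    # taking (study, trial), so the study loop over the objective could look
    # roughly like
    #
    #   self.study.optimize(
    #       self.objective,
    #       n_trials=None,
    #       callbacks=[self.__check_stopping, self.__create_checkpointing],
    #   )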