"""Hyperparameter optimizer using optuna."""importpickleimportoptunafrommala.common.parametersimportprintoutfrommala.network.hyper_optimportHyperOptfrommala.network.objective_baseimportObjectiveBasefrommala.network.naswot_prunerimportNASWOTPrunerfrommala.network.multi_training_prunerimportMultiTrainingPrunerfrommala.common.parallelizerimportparallel_warn

class HyperOptOptuna(HyperOpt):
    """Hyperparameter optimizer using Optuna.

    Parameters
    ----------
    params : mala.common.parameters.Parameters
        Parameters used to create this hyperparameter optimizer.

    data : mala.datahandling.data_handler.DataHandler
        DataHandler holding the data for the hyperparameter optimization.

    use_pkl_checkpoints : bool
        If true, .pkl checkpoints will be created.

    Attributes
    ----------
    params : mala.common.parameters.Parameters
        MALA Parameters object.

    objective : mala.network.objective_base.ObjectiveBase
        MALA objective to be optimized, i.e., a MALA NN model training.

    study : optuna.study.Study
        An Optuna study used to collect the results of the hyperparameter
        optimization.
    """

    def __init__(self, params, data, use_pkl_checkpoints=False):
        super(HyperOptOptuna, self).__init__(
            params, data, use_pkl_checkpoints=use_pkl_checkpoints
        )
        self.params = params

        # Make the sampler behave in a reproducible way, if so specified by
        # the user.
        sampler = optuna.samplers.TPESampler(
            seed=params.manual_seed,
            multivariate=params.hyperparameters.use_multivariate,
        )

        # See if the user specified a pruner.
        pruner = None
        if self.params.hyperparameters.pruner is not None:
            if self.params.hyperparameters.pruner == "naswot":
                pruner = NASWOTPruner(self.params, data)
            elif self.params.hyperparameters.pruner == "multi_training":
                if self.params.hyperparameters.number_training_per_trial > 1:
                    pruner = MultiTrainingPruner(self.params)
                else:
                    printout(
                        "MultiTrainingPruner requested, but only one "
                        "training per trial specified; skipping pruner "
                        "creation."
                    )
            else:
                raise Exception("Invalid pruner type selected.")

        # Create the study.
        if self.params.hyperparameters.rdb_storage is None:
            self.study = optuna.create_study(
                direction=self.params.hyperparameters.direction,
                sampler=sampler,
                study_name=self.params.hyperparameters.study_name,
                pruner=pruner,
            )
        else:
            if self.params.hyperparameters.study_name is None:
                raise Exception(
                    "If RDB storage is used, a name for the study "
                    "has to be provided."
                )
            if "sqlite" in self.params.hyperparameters.rdb_storage:
                engine_kwargs = {
                    "connect_args": {
                        "timeout": self.params.hyperparameters.sqlite_timeout
                    }
                }
            else:
                engine_kwargs = None

            rdb_storage = optuna.storages.RDBStorage(
                url=self.params.hyperparameters.rdb_storage,
                heartbeat_interval=self.params.hyperparameters.rdb_storage_heartbeat,
                engine_kwargs=engine_kwargs,
            )

            self.study = optuna.create_study(
                direction=self.params.hyperparameters.direction,
                sampler=sampler,
                study_name=self.params.hyperparameters.study_name,
                storage=rdb_storage,
                load_if_exists=True,
                pruner=pruner,
            )
        self._checkpoint_counter = 0
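    # A minimal construction sketch (comment only, not executed as part of
    # this module). It assumes a MALA Parameters object and a DataHandler
    # already filled with snapshots; the data setup itself is omitted and
    # the study/storage names are placeholders.
    #
    #   params = mala.Parameters()
    #   params.hyperparameters.n_trials = 20
    #   params.hyperparameters.direction = "minimize"
    #   # Optional: persist trials in an RDB storage (requires a study name).
    #   params.hyperparameters.study_name = "example_study"
    #   params.hyperparameters.rdb_storage = "sqlite:///example_study.db"
    #   data_handler = mala.DataHandler(params)  # snapshots added elsewhere
    #   hyperopt = HyperOptOptuna(params, data_handler)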
    def perform_study(self):
        """
        Perform the study, i.e. the optimization.

        This is done by sampling a certain subset of network architectures.
        In this case, Optuna is used.

        Returns
        -------
        best_trial_loss : float
            Loss of the best trial.
        """
        # The parameters could have changed.
        self.objective = ObjectiveBase(self.params, self._data_handler)

        # Fill callback list based on user checkpoint wishes.
        callback_list = [self.__check_stopping]
        if self.params.hyperparameters.checkpoints_each_trial != 0:
            callback_list.append(self.__create_checkpointing)

        self.study.optimize(
            self.objective, n_trials=None, callbacks=callback_list
        )

        # Return the best loss value we could achieve.
        return self.study.best_value
    def set_optimal_parameters(self):
        """
        Set the optimal parameters found in the present study.

        The parameters will be written to the parameter object with which
        the hyperparameter optimizer was created.
        """
        # Parse the parameters from the best trial.
        self.objective.parse_trial_optuna(self.study.best_trial)
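    # Typical optimization flow (comment-only sketch; `hyperopt` constructed
    # as in the sketch above is an assumption):
    #
    #   best_loss = hyperopt.perform_study()
    #   hyperopt.set_optimal_parameters()
    #   # `params` now carries the best hyperparameters found and can be
    #   # used to train a final network.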
    def get_trials_from_study(self):
        """
        Return the trials from the last study.

        Only returns completed trials.

        Returns
        -------
        last_trials : list
            A list of optuna.FrozenTrial objects.
        """
        return self.study.get_trials(
            states=(optuna.trial.TrialState.COMPLETE,)
        )
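    # Inspecting finished trials (comment-only sketch): each entry is an
    # optuna.trial.FrozenTrial, so the sampled hyperparameters and objective
    # value are available via the standard Optuna attributes.
    #
    #   for t in hyperopt.get_trials_from_study():
    #       print(t.number, t.value, t.params)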
    @staticmethod
    def requeue_zombie_trials(study_name, rdb_storage):
        """
        Put zombie trials back into the queue to be investigated.

        When using Optuna with scheduling systems in HPC infrastructure,
        zombie trials can occur. These are trials that are still marked as
        "RUNNING", but are, in actuality, dead, since the HPC job ended.
        This function takes a saved hyperparameter study and marks all
        "RUNNING" trials as "WAITING". Upon the next execution from a
        checkpoint, they will be executed.

        BE CAREFUL! DO NOT APPLY THIS TO A RUNNING STUDY, IT WILL MESS THE
        STUDY UP! ONLY USE THIS ONCE ALL JOBS HAVE FINISHED, TO CLEAN UP,
        AND THEN RESUBMIT!

        Parameters
        ----------
        study_name : string
            Name of the study in the storage. Same as the checkpoint name.

        rdb_storage : string
            Address of the RDB storage to be cleaned.
        """
        study_to_clean = optuna.load_study(
            study_name=study_name, storage=rdb_storage
        )
        parallel_warn(
            "WARNING: You are about to clean/requeue a study."
            " This operation should not be done to an already"
            " running study."
        )
        trials = study_to_clean.get_trials()
        cleaned_trials = []
        for trial in trials:
            if trial.state == optuna.trial.TrialState.RUNNING:
                kwds = dict(
                    trial_id=trial._trial_id,
                    state=optuna.trial.TrialState.WAITING,
                )
                if hasattr(study_to_clean._storage, "set_trial_state"):
                    # Optuna 2.x
                    study_to_clean._storage.set_trial_state(**kwds)
                else:
                    # Optuna 3.x
                    study_to_clean._storage.set_trial_state_values(
                        values=None, **kwds
                    )
                cleaned_trials.append(trial.number)
        printout("Cleaned trials: ", cleaned_trials, min_verbosity=0)
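    # Cleaning up after crashed or killed HPC jobs (comment-only sketch; the
    # storage URL and study name are placeholders). Run this only once all
    # jobs belonging to the study have terminated, then resubmit.
    #
    #   HyperOptOptuna.requeue_zombie_trials(
    #       study_name="example_study",
    #       rdb_storage="sqlite:///example_study.db",
    #   )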
    @classmethod
    def resume_checkpoint(
        cls,
        checkpoint_name,
        alternative_storage_path=None,
        no_data=False,
        use_pkl_checkpoints=False,
    ):
        """
        Prepare resumption of hyperparameter optimization from a checkpoint.

        Please note that to actually resume the optimization,
        HyperOptOptuna.perform_study() still has to be called.

        Parameters
        ----------
        checkpoint_name : string
            Name of the checkpoint from which the hyperparameter
            optimization is resumed.

        alternative_storage_path : string
            Alternative storage string to load the study from. For
            applications on an HPC cluster it might be necessary to slightly
            modify the storage path between runs, since the SQL server might
            be running on different nodes each time.

        no_data : bool
            If True, the data won't actually be loaded into RAM or scaled.
            This can be useful for cases where a checkpoint is loaded for
            analysis purposes.

        use_pkl_checkpoints : bool
            If true, .pkl checkpoints will be loaded.

        Returns
        -------
        loaded_params : mala.common.parameters.Parameters
            The Parameters saved in the checkpoint.

        new_datahandler : mala.datahandling.data_handler.DataHandler
            The data handler reconstructed from the checkpoint.

        new_hyperopt : HyperOptOptuna
            The hyperparameter optimizer reconstructed from the checkpoint.
        """
        loaded_params, new_datahandler, optimizer_name = (
            cls._resume_checkpoint(
                checkpoint_name,
                no_data=no_data,
                use_pkl_checkpoints=use_pkl_checkpoints,
            )
        )
        if alternative_storage_path is not None:
            loaded_params.hyperparameters.rdb_storage = (
                alternative_storage_path
            )
        new_hyperopt = HyperOptOptuna.load_from_file(
            loaded_params, optimizer_name, new_datahandler
        )

        return loaded_params, new_datahandler, new_hyperopt
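    # Resuming an interrupted optimization (comment-only sketch; the
    # checkpoint name is a placeholder). perform_study() has to be called
    # again to actually continue sampling trials.
    #
    #   loaded_params, new_datahandler, new_hyperopt = (
    #       HyperOptOptuna.resume_checkpoint("example_checkpoint")
    #   )
    #   new_hyperopt.perform_study()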
    @classmethod
    def load_from_file(cls, params, file_path, data):
        """
        Load a hyperparameter optimizer from a file.

        Parameters
        ----------
        params : mala.common.parameters.Parameters
            Parameters object with which the hyperparameter optimizer
            should be created. Has to be compatible with data.

        file_path : string
            Path to the file from which the hyperparameter optimizer should
            be loaded.

        data : mala.datahandling.data_handler.DataHandler
            DataHandler holding the training data.

        Returns
        -------
        loaded_hyperopt : HyperOptOptuna
            The hyperparameter optimizer that was loaded from the file.
        """
        # First, load the checkpoint.
        if params.hyperparameters.rdb_storage is None:
            with open(file_path, "rb") as handle:
                loaded_study = pickle.load(handle)

            # Now, create the hyperparameter optimizer with it.
            loaded_hyperopt = HyperOptOptuna(params, data)
            loaded_hyperopt.study = loaded_study
        else:
            loaded_hyperopt = HyperOptOptuna(params, data)

        return loaded_hyperopt
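    # Loading directly from a pickled study file (comment-only sketch; the
    # file name is a placeholder). Without RDB storage, checkpoints are
    # written as "<checkpoint_name>_hyperopt.pth" pickle files, see
    # __create_checkpointing below.
    #
    #   hyperopt = HyperOptOptuna.load_from_file(
    #       params, "example_checkpoint_hyperopt.pth", data_handler
    #   )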
    def __get_number_of_completed_trials(self, study):
        """
        Get the number of completed trials from a study.

        Parameters
        ----------
        study : optuna.study.Study
            Study from which the number of completed trials should be
            extracted.

        Returns
        -------
        number_of_completed_trials : int
            Number of completed trials.
        """
        # How to calculate this depends on whether or not a heartbeat was
        # used. If one was used, then both COMPLETE and RUNNING trials
        # can be taken into account, as it can be expected that RUNNING
        # trials will actually finish. If no heartbeat is used,
        # then RUNNING trials might be zombie trials.
        # See the references in __check_stopping below.
        if self.params.hyperparameters.rdb_storage_heartbeat is None:
            return len(
                [
                    t
                    for t in study.trials
                    if t.state == optuna.trial.TrialState.COMPLETE
                ]
            )
        else:
            return len(
                [
                    t
                    for t in study.trials
                    if t.state == optuna.trial.TrialState.COMPLETE
                    or t.state == optuna.trial.TrialState.RUNNING
                ]
            )

    def __check_stopping(self, study, trial):
        """
        Check if the maximum number of trials has already been reached.

        If so, stop the study.

        Parameters
        ----------
        study : optuna.study.Study
            Study in which the trial is running.

        trial : optuna.trial.FrozenTrial
            Trial for which the stopping condition should be tested.
        """
        # How to check for this depends on whether or not a heartbeat was
        # used. If one was used, then both COMPLETE and RUNNING trials
        # can be taken into account, as it can be expected that RUNNING
        # trials will actually finish. If no heartbeat is used,
        # then RUNNING trials might be zombie trials.
        # See
        # https://github.com/optuna/optuna/issues/1883#issuecomment-841844834
        # https://github.com/optuna/optuna/issues/1883#issuecomment-842106950
        completed_trials = self.__get_number_of_completed_trials(study)
        if completed_trials >= self.params.hyperparameters.n_trials:
            self.study.stop()

        # Only check if there are trials to be checked.
        if completed_trials > 0:
            if (
                self.params.hyperparameters.number_bad_trials_before_stopping
                is not None
                and self.params.hyperparameters.number_bad_trials_before_stopping
                > 0
            ):
                if (
                    trial.number - self.study.best_trial.number
                    >= self.params.hyperparameters.number_bad_trials_before_stopping
                ):
                    printout(
                        "No new best trial found in",
                        self.params.hyperparameters.number_bad_trials_before_stopping,
                        "attempts, stopping the study.",
                    )
                    self.study.stop()

    def __create_checkpointing(self, study, trial):
        """
        Create a checkpoint of the Optuna study, if necessary.

        This is done based on an internal checkpoint counter.

        Parameters
        ----------
        study : optuna.study.Study
            Study in which the trial is running.

        trial : optuna.trial.FrozenTrial
            Trial for which the checkpoint may be created.
        """
        self._checkpoint_counter += 1
        need_to_checkpoint = False

        if (
            self._checkpoint_counter
            >= self.params.hyperparameters.checkpoints_each_trial
            and self.params.hyperparameters.checkpoints_each_trial > 0
        ):
            need_to_checkpoint = True
            printout(
                str(self.params.hyperparameters.checkpoints_each_trial)
                + " trials have passed, creating a "
                "checkpoint for hyperparameter "
                "optimization.",
                min_verbosity=0,
            )
        if (
            self.params.hyperparameters.checkpoints_each_trial < 0
            and self.__get_number_of_completed_trials(study) > 0
        ):
            if trial.number == study.best_trial.number:
                need_to_checkpoint = True
                printout(
                    "Best trial is "
                    + str(trial.number)
                    + ", creating a checkpoint for it.",
                    min_verbosity=0,
                )

        if need_to_checkpoint is True:
            # We need to create a checkpoint!
            self._checkpoint_counter = 0
            self._save_params_and_scaler()

            # The study only has to be saved if no RDB storage is used.
            if self.params.hyperparameters.rdb_storage is None:
                hyperopt_name = (
                    self.params.hyperparameters.checkpoint_name
                    + "_hyperopt.pth"
                )
                with open(hyperopt_name, "wb") as handle:
                    pickle.dump(self.study, handle, protocol=4)