trainer
Trainer class for training a network.
- class Trainer(params, network, data, optimizer_dict=None)
Bases: Runner
A class for training a neural network.
- Parameters:
params (mala.common.parameters.Parameters) – Parameters used to create this Trainer object.
network (mala.network.network.Network) – Network which is being trained.
data (mala.datahandling.data_handler.DataHandler) – DataHandler holding the training data.
use_pkl_checkpoints (bool) – If true, .pkl checkpoints will be created.
- classmethod load_run(run_name, path='./', zip_run=True, params_format='json', load_runner=True, prepare_data=True)
Load a run.
- Parameters:
run_name (str) – Name under which the run is saved.
path (str) – Path where the run is saved.
zip_run (bool) – If True, MALA will attempt to load the run from a .zip file. If False, MALA will attempt to load it from separate files.
params_format (str) – Can be “json” or “pkl”, depending on the format in which the parameters were saved. Default is “json”.
load_runner (bool) – If True, a Runner object will be created/loaded for further use.
prepare_data (bool) – If True, the data will be loaded into memory. This is required to continue training a model.
- Returns:
loaded_params (mala.common.parameters.Parameters) – The Parameters saved to file.
loaded_network (mala.network.network.Network) – The network saved to file.
new_datahandler (mala.datahandling.data_handler.DataHandler) – The data handler reconstructed from file.
new_trainer (Trainer) – (Optional) The runner reconstructed from file. For the Tester and Predictor classes, this is simply a newly instantiated object.
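How zip_run and params_format select what is read from disk can be sketched as follows. This is a minimal illustration, not MALA's implementation: the helper read_saved_params and the file names <run_name>.zip and <run_name>.params.<format> are hypothetical assumptions about the on-disk layout. The sketch only reads back the raw parameters; the real load_run also reconstructs the network, the data handler and, optionally, the runner.

```python
import json
import os
import pickle
import zipfile


def read_saved_params(run_name, path="./", zip_run=True, params_format="json"):
    """Return the raw saved parameters of a run as a dict.

    Hypothetical layout (an assumption, not MALA's documented format):
    the parameters live in "<run_name>.params.<params_format>", either
    inside "<run_name>.zip" (zip_run=True) or as a separate file in `path`.
    """
    if params_format not in ("json", "pkl"):
        raise ValueError("params_format must be 'json' or 'pkl'")
    member = f"{run_name}.params.{params_format}"
    if zip_run:
        # Read the parameters file directly out of the archive.
        with zipfile.ZipFile(os.path.join(path, run_name + ".zip")) as archive:
            raw = archive.read(member)
    else:
        # Read the parameters file as a separate file on disk.
        with open(os.path.join(path, member), "rb") as f:
            raw = f.read()
    if params_format == "json":
        return json.loads(raw.decode("utf-8"))
    # Only unpickle files you trust: pickle can execute arbitrary code.
    return pickle.loads(raw)
```

Keeping the format check up front means an unsupported params_format fails immediately instead of after the file has been located.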
- classmethod run_exists(run_name, params_format='json', zip_run=True)
Check whether a saved run exists.
Returns True if it does.
- Parameters:
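run_name (str) – Name of the checkpoint.
params_format (str) – Save format of the parameters. Can be “json” or “pkl”.
zip_run (bool) – If True, MALA checks for a .zip file. If False, it checks for separate files.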
- Returns:
checkpoint_exists – True if the checkpoint exists, False otherwise.
- Return type:
bool
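A file-existence check in the spirit of run_exists could look like the sketch below. The function run_exists_sketch, its extra path argument (added only to make the sketch testable; the documented method takes no path), and the assumed file names <run_name>.zip and <run_name>.params.<format> are all hypothetical. MALA's actual check may inspect other or additional files.

```python
import os


def run_exists_sketch(run_name, params_format="json", zip_run=True, path="./"):
    """Return True if a saved run named `run_name` appears to exist.

    Hypothetical layout (an assumption, not MALA's documented format):
    a zipped run is stored as "<run_name>.zip"; an unzipped run is
    identified by its parameters file "<run_name>.params.<params_format>".
    """
    if zip_run:
        target = f"{run_name}.zip"
    else:
        target = f"{run_name}.params.{params_format}"
    # Only checks that the file exists; it does not validate its contents.
    return os.path.isfile(os.path.join(path, target))
```

With the documented API, such a check typically guards resuming a training: if Trainer.run_exists("my_run") returns True, call Trainer.load_run("my_run") to continue; otherwise construct a fresh Trainer. Here "my_run" is a placeholder name.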