trainer

Trainer class for training a network.

class Trainer(params, network, data, optimizer_dict=None)

Bases: Runner

A class for training a neural network.

Parameters:
classmethod load_run(run_name, path='./', zip_run=True, params_format='json', load_runner=True, prepare_data=True)

Load a run.

Parameters:
  • run_name (str) – Name under which the run is saved.

  • path (str) – Path where the run is saved.

  • zip_run (bool) – If True, MALA will attempt to load the run from a .zip file. If False, it will attempt to load separately saved files instead.

  • params_format (str) – Format in which the parameters were saved, either “json” or “pkl”. Default is “json”.

  • load_runner (bool) – If True, a Runner object will be created/loaded for further use.

  • prepare_data (bool) – If True, the data will be loaded into memory. This is needed when resuming the training of a model.
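The way path, zip_run and params_format combine to select what gets loaded can be pictured with a small sketch. It is hypothetical: the helper expected_run_files and the file names it produces (<run_name>.zip, <run_name>.params.<format>, <run_name>.network.pth) are assumptions for illustration, not MALA's documented on-disk layout.

```python
from pathlib import Path


def expected_run_files(run_name, path="./", zip_run=True, params_format="json"):
    """Return the file(s) a loader would look for (hypothetical naming scheme)."""
    if params_format not in ("json", "pkl"):
        raise ValueError(f"params_format must be 'json' or 'pkl', got {params_format!r}")
    base = Path(path)
    if zip_run:
        # Assumption: every component of the run is bundled into one archive.
        return [base / f"{run_name}.zip"]
    # Assumption: without zipping, each component lives in its own file,
    # and only the parameter file depends on params_format.
    return [
        base / f"{run_name}.params.{params_format}",
        base / f"{run_name}.network.pth",
    ]


print(expected_run_files("my_run", path="runs", zip_run=False, params_format="pkl"))
```

Validating params_format up front mirrors the documented choice between “json” and “pkl” and fails early on a typo rather than on a missing file.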

Returns:

  • loaded_params (mala.common.parameters.Parameters) – The Parameters saved to file.

  • loaded_network (mala.network.network.Network) – The network saved to file.

  • new_datahandler (mala.datahandling.data_handler.DataHandler) – The data handler reconstructed from file.

  • new_trainer (Trainer) – (Optional, returned only if load_runner is True) The runner reconstructed from file. For the Tester and Predictor classes, this is simply a newly instantiated object.

classmethod run_exists(run_name, params_format='json', zip_run=True)

Check whether a saved run (checkpoint) with the given name exists.

Returns True if it does.

Parameters:
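  • run_name (str) – Name of the checkpoint.

  • params_format (str) – Save format of the parameters, either “json” or “pkl”.

  • zip_run (bool) – If True, check for a run saved as a .zip file; if False, check for separately saved files.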

Returns:

checkpoint_exists – True if the checkpoint exists, False otherwise.

Return type:

bool
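A common use of run_exists is deciding whether to resume a checkpointed run or start a fresh one. The sketch below mimics that decision without MALA. run_exists_sketch and resume_or_start are hypothetical stand-ins; the sketch assumes a run is stored as <run_name>.zip (when zip_run is True) or as <run_name>.params.<format> otherwise, and, unlike the documented signature, it takes a path argument so the demo can use a temporary directory.

```python
import tempfile
from pathlib import Path


def run_exists_sketch(run_name, path="./", params_format="json", zip_run=True):
    """Hypothetical stand-in for Trainer.run_exists: True if the run's file is on disk."""
    base = Path(path)
    if zip_run:
        candidate = base / f"{run_name}.zip"
    else:
        candidate = base / f"{run_name}.params.{params_format}"
    return candidate.is_file()


def resume_or_start(run_name, path):
    # With MALA, the "resume" branch would call Trainer.load_run(...) with
    # load_runner=True and prepare_data=True; the "start" branch would build
    # a new Trainer(params, network, data).
    return "resume" if run_exists_sketch(run_name, path) else "start"


with tempfile.TemporaryDirectory() as tmp:
    print(resume_or_start("my_run", tmp))  # no checkpoint on disk yet
    (Path(tmp) / "my_run.zip").touch()      # simulate a saved checkpoint
    print(resume_or_start("my_run", tmp))
```

Keeping the existence check separate from the loading call keeps the branch logic trivial to test, and it lets a training script restart after an interruption without manual intervention.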

train_network()

Train a network using data given by a DataHandler.
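The internals of train_network are not described here. As a rough, dependency-free illustration of the kind of epoch-based loop a trainer runs, the sketch below fits a one-parameter linear model y = w * x with full-batch gradient descent, evaluates a validation loss after every epoch, and stops early once that loss stops improving. Everything in it (train_sketch, the model, the learning rate, the patience value) is a hypothetical stand-in; MALA itself trains PyTorch networks with settings taken from its Parameters object.

```python
def train_sketch(train, val, lr=0.1, max_epochs=200, patience=5):
    """Fit y = w * x by gradient descent on mean squared error, with early stopping."""

    def mse(w, data):
        return sum((w * x - y) ** 2 for x, y in data) / len(data)

    w = 0.0
    best_w, best_val = w, mse(w, val)
    epochs_without_improvement = 0
    for _ in range(max_epochs):
        # Gradient of the training MSE with respect to w.
        grad = sum(2 * (w * x - y) * x for x, y in train) / len(train)
        w -= lr * grad
        val_loss = mse(w, val)
        if val_loss < best_val:
            # Keep the best model seen so far, as a checkpoint would.
            best_w, best_val = w, val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # early stopping: validation loss has plateaued
    return best_w, best_val


train = [(x, 3.0 * x) for x in (0.1, 0.5, 1.0, 1.5)]
val = [(x, 3.0 * x) for x in (0.3, 1.2)]
w, loss = train_sketch(train, val)
print(f"w = {w:.4f}, validation MSE = {loss:.2e}")
```

Returning the best validation-scoring parameters rather than the last ones means a late, noisy epoch cannot overwrite a better earlier state.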