tester

Tester class for testing a network.

class Tester(params, network, data, observables_to_test=['ldos'], output_format='list')

Bases: Runner

A class for testing a neural network.

It enables easy inference across a test set.

Parameters:
  • params (mala.common.parameters.Parameters) – Parameters used to create this Tester object.

  • network (mala.network.network.Network) – Network which is being tested.

  • data (mala.datahandling.data_handler.DataHandler) – DataHandler holding the test data.

  • observables_to_test (list) –

    List of observables to test. Supported are:

    • “ldos”: MSE loss of the LDOS.

    • “band_energy”: Band energy error.

    • “band_energy_full”: Band energy absolute values (only works with output_format “list”, as both actual and predicted values are returned).

    • “total_energy”: Total energy error.

    • “total_energy_full”: Total energy absolute values (only works with output_format “list”, as both actual and predicted values are returned).

    • “number_of_electrons”: Number of electrons (the Fermi energy is not determined dynamically for this quantity).

    • “density”: MAPE of the density prediction.

    • “dos”: MAPE of the DOS prediction.

  • output_format (string) – Can be “list” or “mae”. If “list”, then a list of results across all snapshots is returned. If “mae”, then the MAE across all snapshots will be calculated and returned.
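The error measures named in the list above (MSE, MAPE) can be illustrated with a minimal sketch. The helper names `mse` and `mape` are hypothetical and not part of MALA; the exact normalization MALA applies (for example, whether MAPE is reported in percent) is an assumption here.

```python
def mse(actual, predicted):
    """Mean squared error between two equally long sequences of numbers."""
    if len(actual) != len(predicted):
        raise ValueError("actual and predicted must have the same length")
    return sum((a - p) ** 2 for a, p in zip(actual, predicted)) / len(actual)


def mape(actual, predicted):
    """Mean absolute percentage error, in percent.

    Assumption: no entry of `actual` is zero, and the result is scaled by 100.
    """
    if len(actual) != len(predicted):
        raise ValueError("actual and predicted must have the same length")
    relative = sum(abs((a - p) / a) for a, p in zip(actual, predicted))
    return 100.0 * relative / len(actual)
```

Both helpers operate on flat Python sequences to keep the sketch dependency-free; real LDOS, DOS, or density data would typically be multi-dimensional arrays.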

get_energy_targets_and_predictions(snapshot_number, data_type='te')

Get the energy targets and predictions for a single snapshot.

Parameters:
  • snapshot_number (int) – Snapshot to test.

  • data_type (str) – One of ‘tr’, ‘va’, or ‘te’, indicating the partition (training, validation, or test) to be tested.

Returns:

results – A dictionary containing the errors for the selected observables.

Return type:

dict

predict_targets(snapshot_number, data_type='te')

Get actual and predicted energy outputs for a snapshot.

Parameters:
  • snapshot_number (int) – Snapshot for which the prediction is done.

  • data_type (str) – One of ‘tr’, ‘va’, or ‘te’, indicating the partition (training, validation, or test) to be tested.

Returns:

  • actual_outputs (numpy.ndarray) – Actual outputs for snapshot.

  • predicted_outputs (numpy.ndarray) – Predicted outputs for snapshot.

test_all_snapshots()

Test the selected observables for all snapshots.

Returns:

results – A dictionary containing the errors for the selected observables, either as a list across all snapshots or as the MAE, depending on output_format.

Return type:

dict
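The relationship between the “list” and “mae” output formats can be sketched as follows. The helper `to_mae` and the numbers are purely illustrative and not taken from MALA; the sketch assumes the MAE is the mean of the absolute per-snapshot values for each observable.

```python
def to_mae(results_per_snapshot):
    """Collapse {observable: [value per snapshot]} into {observable: MAE}.

    Assumption: MAE is the mean of the absolute per-snapshot values.
    """
    return {
        name: sum(abs(value) for value in values) / len(values)
        for name, values in results_per_snapshot.items()
    }


# Made-up per-snapshot errors in the shape of a "list"-style result.
list_results = {
    "band_energy": [0.5, -1.5, 1.0],
    "total_energy": [2.0, -2.0, 2.0],
}
mae_results = to_mae(list_results)
```

Observables such as “band_energy_full” return both actual and predicted values per snapshot, so a reduction of this kind does not apply to them.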

test_snapshot(snapshot_number, data_type='te')

Test the selected observables for a single snapshot.

Parameters:
  • snapshot_number (int) – Snapshot to test.

  • data_type (str) – One of ‘tr’, ‘va’, or ‘te’, indicating the partition (training, validation, or test) to be tested.

Returns:

results – A dictionary containing the errors for the selected observables.

Return type:

dict
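A minimal sketch of testing a hand-picked set of snapshots with test_snapshot is shown below. The helper `collect_snapshot_errors` and the `StubTester` class are hypothetical; the stub only stands in for a configured Tester so the example runs without MALA, and its return values are made up.

```python
def collect_snapshot_errors(tester, snapshot_numbers, data_type="te"):
    """Call tester.test_snapshot for each snapshot and group values by observable."""
    collected = {}
    for number in snapshot_numbers:
        results = tester.test_snapshot(number, data_type=data_type)
        for observable, value in results.items():
            collected.setdefault(observable, []).append(value)
    return collected


class StubTester:
    """Stand-in for a Tester; mimics only the test_snapshot signature."""

    def __init__(self):
        self.calls = []

    def test_snapshot(self, snapshot_number, data_type="te"):
        # Record the call and return a made-up error for a single observable.
        self.calls.append((snapshot_number, data_type))
        return {"ldos": float(snapshot_number)}


stub = StubTester()
errors = collect_snapshot_errors(stub, [0, 1, 2], data_type="va")
```

With a real Tester instance, the same helper would be called as `collect_snapshot_errors(tester, [0, 1, 2])`, relying only on the test_snapshot signature documented above.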