tester
Tester class for testing a network.
- class Tester(params, network, data, observables_to_test=['ldos'], output_format='list')
Bases: Runner
A class for testing a neural network.
It enables easy inference across a test set.
- Parameters:
params (mala.common.parameters.Parameters) – Parameters used to create this Tester object.
network (mala.network.network.Network) – Network which is being tested.
data (mala.datahandling.data_handler.DataHandler) – DataHandler holding the test data.
observables_to_test (list) –
List of observables to test. Supported are:
“ldos”: MSE loss of the LDOS.
“band_energy”: Band energy error.
“band_energy_full”: Band energy absolute values (only works with output_format “list”, since both actual and predicted values are returned).
“total_energy”: Total energy error.
“total_energy_full”: Total energy absolute values (only works with output_format “list”, since both actual and predicted values are returned).
“number_of_electrons”: Number of electrons (the Fermi energy is not determined dynamically for this quantity).
“density”: MAPE of the density prediction.
“dos”: MAPE of the DOS prediction.
output_format (string) – Can be “list” or “mae”. If “list”, then a list of results across all snapshots is returned. If “mae”, then the MAE across all snapshots will be calculated and returned.
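The observables above use two kinds of error measure: an MSE for the LDOS and a MAPE for the density and the DOS. As a rough illustration of what these numbers mean, here is a minimal sketch that computes both on toy arrays standing in for one snapshot. The helper names `mse` and `mape` are hypothetical, this is not MALA's implementation, and the exact MAPE convention (percentage scaling, treatment of zero reference values) is an assumption.

```python
import numpy as np


def mse(actual: np.ndarray, predicted: np.ndarray) -> float:
    """Mean squared error, the kind of loss described for "ldos"."""
    return float(np.mean((predicted - actual) ** 2))


def mape(actual: np.ndarray, predicted: np.ndarray) -> float:
    """Mean absolute percentage error, as described for "density" and "dos".

    Assumption: no entry of `actual` is zero; MALA's handling of zeros
    is not covered by this sketch.
    """
    return float(np.mean(np.abs((predicted - actual) / actual)) * 100.0)


# Toy arrays standing in for one snapshot's actual and predicted values.
actual = np.array([1.0, 2.0, 4.0])
predicted = np.array([1.1, 1.8, 4.0])
print(mse(actual, predicted))
print(mape(actual, predicted))
```

MSE penalizes large deviations strongly and keeps the units of the squared quantity, while MAPE is scale-free, which is convenient when comparing densities or DOS curves of different magnitude.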
- get_energy_targets_and_predictions(snapshot_number, data_type='te')
Get the energy targets and predictions for a single snapshot.
- Parameters:
snapshot_number (int) – Snapshot to test.
data_type (str) – ‘tr’, ‘va’, or ‘te’, selecting the training, validation, or test partition.
- Returns:
results – A dictionary containing the errors for the selected observables.
- Return type:
dict
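The observable list distinguishes plain energy observables such as “band_energy”, which report an error, from their “_full” counterparts, which report both the actual and the predicted value. A minimal sketch of that distinction, assuming the energies for one snapshot arrive as plain floats keyed by observable name and that the error is a simple absolute difference. `energy_results` is a hypothetical helper, not part of MALA, and MALA's own units or per-atom normalization may differ.

```python
from typing import Dict, List, Union


def energy_results(
    targets: Dict[str, float],
    predictions: Dict[str, float],
    observables: List[str],
) -> Dict[str, Union[float, List[float]]]:
    """Hypothetical: build a per-snapshot result dict for energy observables."""
    results: Dict[str, Union[float, List[float]]] = {}
    for name in observables:
        if name.endswith("_full"):
            base = name[: -len("_full")]
            # "_full" variants return [actual, predicted] instead of an error.
            results[name] = [targets[base], predictions[base]]
        else:
            # Assumption: the error is the plain absolute difference.
            results[name] = abs(predictions[name] - targets[name])
    return results


# Toy energies for one snapshot (values are invented for illustration).
targets = {"band_energy": -10.0, "total_energy": -250.0}
predictions = {"band_energy": -10.25, "total_energy": -249.5}
print(energy_results(targets, predictions, ["band_energy", "total_energy_full"]))
# {'band_energy': 0.25, 'total_energy_full': [-250.0, -249.5]}
```

Because a “_full” entry is a pair rather than a single number, it cannot be averaged into an MAE, which is why those observables only work with output_format “list”.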
- predict_targets(snapshot_number, data_type='te')
Get actual and predicted energy outputs for a snapshot.
- Parameters:
snapshot_number (int) – Snapshot for which the prediction is done.
data_type (str) – ‘tr’, ‘va’, or ‘te’, selecting the training, validation, or test partition.
- Returns:
actual_outputs (numpy.ndarray) – Actual outputs for snapshot.
predicted_outputs (numpy.ndarray) – Predicted outputs for snapshot.
- test_all_snapshots()
Test the selected observables for all snapshots.
- Returns:
results – A dictionary containing the errors for the selected observables, either as a list over all snapshots or as the MAE across them, depending on output_format.
- Return type:
dict
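test_all_snapshots combines the per-snapshot results according to output_format. The following is a minimal sketch of that aggregation step, assuming each snapshot yields a dict mapping observable names to a number (or, for “_full” observables, a pair). The helper `aggregate` is hypothetical, and reading “MAE across all snapshots” as the mean of the absolute per-snapshot errors is an assumption about MALA's behavior.

```python
from typing import Any, Dict, List

import numpy as np


def aggregate(per_snapshot: List[Dict[str, Any]], output_format: str = "list") -> Dict[str, Any]:
    """Hypothetical: merge per-snapshot result dicts into one result dict."""
    if output_format not in ("list", "mae"):
        raise ValueError(f"Unknown output_format: {output_format!r}")
    if not per_snapshot:
        return {}
    # Collect one entry per snapshot for every observable.
    results: Dict[str, List[Any]] = {name: [] for name in per_snapshot[0]}
    for snapshot in per_snapshot:
        for name, value in snapshot.items():
            results[name].append(value)
    if output_format == "list":
        return results
    for name in results:
        if name.endswith("_full"):
            # Pairs of actual/predicted values have no meaningful MAE.
            raise ValueError(f"{name} requires output_format='list'")
    # Assumption: MAE = mean of absolute per-snapshot errors.
    return {name: float(np.mean(np.abs(values))) for name, values in results.items()}


# Toy per-snapshot errors (invented for illustration).
snapshots = [{"ldos": 0.02, "band_energy": -0.5}, {"ldos": 0.04, "band_energy": 1.5}]
print(aggregate(snapshots, "list"))  # {'ldos': [0.02, 0.04], 'band_energy': [-0.5, 1.5]}
print(aggregate(snapshots, "mae"))
```

Raising on “_full” observables in MAE mode mirrors the restriction stated in the observable list rather than silently producing a meaningless average.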
- test_snapshot(snapshot_number, data_type='te')
Test the selected observables for a single snapshot.
- Parameters:
snapshot_number (int) – Snapshot to test.
data_type (str) – ‘tr’, ‘va’, or ‘te’, selecting the training, validation, or test partition.
- Returns:
results – A dictionary containing the errors for the selected observables.
- Return type:
dict