predictor
Predictor class.
- class Predictor(params, network, data)
Bases: Runner
A class for running predictions using a neural network.
It enables production-level inference.
- Parameters:
params (mala.common.parameters.Parameters) – Parameters used to create this Predictor object.
network (mala.network.network.Network) – Network used for predictions.
data (mala.datahandling.data_handler.DataHandler) – DataHandler; here it does not hold data directly but serves as an interface to the Target and Descriptor objects.
- target_calculator
Target calculator used for predictions. It can also be used for further processing of the predicted quantities.
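Because a Predictor bundles the parameters, network and data handler, one natural pattern for production inference is to build it once and reuse it for many atomic configurations. The minimal sketch below assumes only the `predict_for_atoms(atoms, gather_ldos=False, temperature=None)` signature documented in this section; `predict_batch` is a hypothetical helper, not part of MALA, and the check that the temperature is positive is an added assumption.

```python
from typing import Any, Iterable, List, Optional


def predict_batch(predictor: Any, configurations: Iterable[Any],
                  temperature: Optional[float] = None) -> List[Any]:
    """Hypothetical helper: call predict_for_atoms once per configuration,
    reusing a single predictor object.

    `predictor` is any object exposing the documented
    predict_for_atoms(atoms, gather_ldos=False, temperature=None) method.
    `temperature` (in K) is forwarded unchanged; None keeps the model default.
    """
    # Assumption: a physically meaningful temperature in K is positive.
    if temperature is not None and temperature <= 0:
        raise ValueError("temperature must be given in K and be positive")
    return [predictor.predict_for_atoms(atoms, temperature=temperature)
            for atoms in configurations]
```

After the loop, `predictor.target_calculator` is still available for any further processing.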
- predict_for_atoms(atoms, gather_ldos=False, temperature=None)
Get the predicted LDOS for an atomic configuration.
- Parameters:
atoms (ase.Atoms) – ASE atoms for which the prediction should be done.
gather_ldos (bool) – Only relevant if MPI is used. If True, all descriptors are gathered on rank 0 and the network pass is performed there. This is useful when multiple CPUs are used for the descriptor calculation but only one for the network pass.
temperature (float) – If not None, this temperature is set in the internal target calculator and can be used in subsequent integrations. If None, the default temperature loaded from the model is used. The temperature must be given in K.
- Returns:
predicted_ldos – Predicted LDOS for these atomic positions.
- Return type:
numpy.array
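To make the `gather_ldos=True` data flow concrete, the following pure-Python simulation (no MPI involved) splits descriptors across hypothetical ranks, gathers them on rank 0, and runs a stand-in network only there. `split_across_ranks` and `gather_then_predict` are illustrative names, and the contiguous split is an assumption; MALA's actual domain decomposition and communication are not shown.

```python
from typing import Callable, List, Optional, Sequence


def split_across_ranks(items: Sequence[float], n_ranks: int) -> List[List[float]]:
    """Give each simulated rank one contiguous chunk of the descriptors."""
    if n_ranks < 1:
        raise ValueError("need at least one rank")
    chunk = -(-len(items) // n_ranks)  # ceiling division
    return [list(items[r * chunk:(r + 1) * chunk]) for r in range(n_ranks)]


def gather_then_predict(
    per_rank_descriptors: List[List[float]],
    network: Callable[[List[float]], List[float]],
) -> List[Optional[List[float]]]:
    """Simulate gather_ldos=True: concatenate every rank's descriptors on
    rank 0, run the network pass only there, and leave other ranks empty."""
    gathered = [d for rank_data in per_rank_descriptors for d in rank_data]
    result_on_rank0 = network(gathered)  # the single network pass
    return [result_on_rank0] + [None] * (len(per_rank_descriptors) - 1)
```

With a trivial stand-in network such as doubling each value, the result on rank 0 matches a serial pass over all descriptors, while the other simulated ranks hold nothing.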
- predict_from_qeout(path_to_file, gather_ldos=False)
Get the predicted LDOS for the atomic configuration read from a Quantum ESPRESSO output (QE.out) file.
- Parameters:
path_to_file (string) – Path from which to read the atomic configuration.
gather_ldos (bool) – Only relevant if MPI is used. If True, all descriptors are gathered on rank 0 and the network pass is performed there. This is useful when multiple CPUs are used for the descriptor calculation but only one for the network pass.
- Returns:
predicted_ldos – Predicted LDOS for these atomic positions.
- Return type:
numpy.array
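Conceptually, this method combines reading the atomic positions from a Quantum ESPRESSO output file with `predict_for_atoms`. The sketch below illustrates that idea with ASE's `espresso-out` reader; it is an approximation under that assumption, not MALA's implementation, which may also read further calculation data from the file. `predict_from_qe_output`, `_read_with_ase` and the `read_atoms` parameter are hypothetical names.

```python
from typing import Any, Callable, Optional


def _read_with_ase(path: str) -> Any:
    """Parse a Quantum ESPRESSO output file into an ase.Atoms object."""
    from ase.io import read  # lazy import: ASE is only needed on this path
    # index=-1 returns the last configuration written to the output file
    return read(path, format="espresso-out", index=-1)


def predict_from_qe_output(predictor: Any, path_to_file: str,
                           gather_ldos: bool = False,
                           read_atoms: Optional[Callable[[str], Any]] = None) -> Any:
    """Hypothetical sketch: read atoms from a QE output file, then predict
    their LDOS via the documented predict_for_atoms method.

    `read_atoms` can replace the ASE reader, e.g. for testing without ASE.
    """
    reader = read_atoms if read_atoms is not None else _read_with_ase
    atoms = reader(path_to_file)
    return predictor.predict_for_atoms(atoms, gather_ldos=gather_ldos)
```

Injecting the reader keeps the file-parsing step separate from the prediction step, so either can be swapped independently.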