network
Neural network for MALA.
- class FeedForwardNet(params: Parameters | None = None)
Bases:
Network
Initialize this network as a feed-forward network.
- class Network(params: Parameters | None = None)
Bases:
Module
Central network class for this framework, based on torch.nn.Module.
The correct type of neural network will automatically be instantiated by this class, if possible. Alternatively, the desired network can be instantiated directly by calling its subclass.
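The automatic selection described above is a factory pattern: instantiating the base class returns an instance of a registered subclass, while instantiating a subclass directly skips the lookup. The sketch below illustrates that pattern in plain Python. `Network`, `FeedForwardNet`, the `_registry` dict and the `nn_type` string argument are hypothetical stand-ins; the real class receives a `Parameters` object, and this is not MALA's implementation.

```python
class Network:
    """Hypothetical base class that dispatches to a subclass on construction."""

    # Maps a network-type string to the subclass that implements it.
    _registry = {}

    def __init_subclass__(cls, nn_type=None, **kwargs):
        # Each subclass registers itself under the type name it declares.
        super().__init_subclass__(**kwargs)
        if nn_type is not None:
            Network._registry[nn_type] = cls

    def __new__(cls, nn_type="feed-forward"):
        # Only dispatch when the base class itself is instantiated;
        # calling a subclass directly bypasses the registry lookup.
        if cls is Network:
            try:
                cls = Network._registry[nn_type]
            except KeyError:
                raise ValueError(f"Unsupported network type: {nn_type!r}") from None
        return super().__new__(cls)

    def __init__(self, nn_type="feed-forward"):
        self.nn_type = nn_type


class FeedForwardNet(Network, nn_type="feed-forward"):
    """Hypothetical feed-forward subclass registered under 'feed-forward'."""
```

Because `__new__` returns an instance of a `Network` subclass, Python still calls `__init__` on it, so `Network("feed-forward")` yields a fully initialized `FeedForwardNet`. An unknown type fails early with a `ValueError`.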
- Parameters:
params (mala.common.parameters.Parameters) – Parameters used to create this neural network.
- loss_func
Loss function.
- Type:
function
- mini_batch_size
Size of mini batches propagated through network.
- Type:
int
- number_of_layers
Number of NN layers.
- Type:
int
- params
MALA neural network parameters.
- Type:
mala.common.parameters.ParametersNetwork
- use_ddp
If True, PyTorch's distributed data parallel (DDP) framework will be used.
- Type:
bool
- calculate_loss(output, target)
Calculate the loss for a predicted output and target.
- Parameters:
output (torch.Tensor) – Predicted output.
target (torch.Tensor) – Actual output.
- Returns:
loss_val – Loss value for output and target.
- Return type:
float
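The value returned by calculate_loss depends on which function is stored in `loss_func`. As an illustration only, assuming that function is a mean squared error, the computation reduces to averaging the squared element-wise differences. This matches `torch.nn.functional.mse_loss` with its default `reduction="mean"`. The helper `mse_loss` below is hypothetical and operates on plain Python sequences instead of tensors.

```python
from typing import Sequence


def mse_loss(output: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error between prediction and target (hypothetical helper)."""
    if len(output) != len(target):
        raise ValueError("output and target must have the same length")
    if len(output) == 0:
        raise ValueError("cannot compute a loss on empty inputs")
    # Average of squared element-wise differences ("mean" reduction).
    return sum((o - t) ** 2 for o, t in zip(output, target)) / len(output)
```

For example, `mse_loss([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])` evaluates to `4 / 3`, because only the last element differs, by 2.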
- do_prediction(array)
Predict the output values for an input array.
Interface for performing predictions. The input is assumed to be a scaled torch.Tensor in the correct units. Be aware that the entire array is passed through the network at once, which can be very demanding in terms of RAM.
- Parameters:
array (torch.Tensor) – Input array for which the prediction is to be performed.
- Returns:
predicted_array – Predicted outputs of array.
- Return type:
torch.Tensor
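Because do_prediction pushes the whole array through the network in a single pass, a common way to reduce peak memory is to split the input into fixed-size chunks and concatenate the partial results. The sketch below shows this chunking in plain Python. `predict_in_batches` and its `predict` callable are hypothetical and are not part of the documented API. With real tensors, the same idea would use `Tensor.split(batch_size)` inside `torch.no_grad()` and then `torch.cat` on the outputs.

```python
from typing import Any, Callable, List, Sequence


def predict_in_batches(
    predict: Callable[[Sequence[Any]], List[Any]],
    inputs: Sequence[Any],
    batch_size: int,
) -> List[Any]:
    """Apply `predict` to consecutive slices of `inputs` (hypothetical helper)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    outputs: List[Any] = []
    for start in range(0, len(inputs), batch_size):
        # Only one slice of the input is handed to the model at a time,
        # so peak memory scales with batch_size rather than len(inputs).
        outputs.extend(predict(inputs[start:start + batch_size]))
    return outputs
```

The final slice may be shorter than `batch_size`. For example, five inputs with `batch_size=2` are processed as slices of sizes 2, 2 and 1.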
- abstract forward(inputs)
Abstract method; must be implemented by derived classes.
- Parameters:
inputs (torch.Tensor) – Torch tensor to be propagated.
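To show what "to be implemented by the derived class" means in practice, the sketch below declares `forward` as an abstract method and provides a tiny feed-forward subclass that overrides it. Everything here is a hypothetical stand-in using plain Python lists, not torch modules: `BaseNetwork`, `TinyFeedForward`, and the choice of `tanh` between layers are all assumptions. The `__call__` method imitates how `torch.nn.Module` routes calls to `forward`.

```python
import math
from abc import ABC, abstractmethod
from typing import List, Sequence


class BaseNetwork(ABC):
    """Hypothetical abstract base: subclasses must provide forward()."""

    @abstractmethod
    def forward(self, inputs: Sequence[float]) -> List[float]:
        """Propagate inputs through the network."""

    def __call__(self, inputs: Sequence[float]) -> List[float]:
        # Mirrors the torch convention of calling the module to run forward().
        return self.forward(inputs)


class TinyFeedForward(BaseNetwork):
    """Dense layers with tanh between them; no activation on the output layer."""

    def __init__(self, weights, biases):
        # weights[k] is a list of rows (out_k x in_k); biases[k] has length out_k.
        self.weights = weights
        self.biases = biases

    def forward(self, inputs: Sequence[float]) -> List[float]:
        x = list(inputs)
        last = len(self.weights) - 1
        for k, (w, b) in enumerate(zip(self.weights, self.biases)):
            # Affine map: x_new[i] = sum_j w[i][j] * x[j] + b[i]
            x = [sum(wij * xj for wij, xj in zip(row, x)) + bi
                 for row, bi in zip(w, b)]
            if k != last:
                x = [math.tanh(v) for v in x]
        return x
```

Attempting to instantiate `BaseNetwork` directly raises a `TypeError`, because its abstract `forward` has not been overridden.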
- classmethod load_from_file(params, file)
Load a network from a file.
- Parameters:
params (mala.common.parameters.Parameters) – Parameters object with which the network should be created. Has to be compatible with the network architecture. This is usually ensured by using the same Parameters object (and saving/loading it alongside the network).
file (string or ZipExtFile) – Path to the file from which the network should be loaded.
- Returns:
loaded_network – The network that was loaded from the file.
- Return type: