dlt.train¶
Base Classes¶
class dlt.train.BaseTrainer¶
Generic base trainer object to inherit functionality from.
__call__(loader)¶
Performs an epoch of training or validation.
Parameters: loader (iterable) – The data loader.
cpu(device=0)¶
Sets the trainer to CPU mode.
cuda(device=0)¶
Sets the trainer to GPU mode.
If enabled, the data retrieved from the loader is cast to the GPU before every iteration.
eval()¶
Sets the trainer and models to inference mode.
iterate(loader)¶
Performs an epoch of training or validation.
Parameters: loader (iterable) – The data loader.
load_state_dict(state_dict)¶
Loads the trainer's state.
Parameters: state_dict (dict) – Trainer state. Should be an object returned from a call to state_dict().
loss_names(training=None)¶
Returns the name(s)/key(s) of the training or validation loss(es).
Parameters: training (bool, optional) – If provided, the training losses are returned if True and the validation losses if False. If not provided, the losses of the current mode are returned.
loss_names_training()¶
Returns the name(s)/key(s) of the training loss(es).
loss_names_validation()¶
Returns the name(s)/key(s) of the validation loss(es).
state_dict()¶
Returns the state of the trainer as a dict. It contains an entry for every variable in self.__dict__ that is not one of the models or optimizers.
train()¶
Sets the trainer and models to training mode.
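The methods above support mode switching and checkpointing directly. A minimal sketch, assuming my_trainer is any dlt trainer instance (the file name is made up):
>>> import torch
>>> my_trainer.cuda()   # cast loader data to the GPU each iteration
>>> my_trainer.train()  # training mode
>>> # Save the trainer state (self.__dict__ minus models/optimizers)
>>> torch.save(my_trainer.state_dict(), 'trainer_state.pth')
>>> # Later, restore it from an object returned by state_dict()
>>> my_trainer.load_state_dict(torch.load('trainer_state.pth'))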
class dlt.train.GANBaseTrainer(generator, discriminator, g_optimizer, d_optimizer, d_iter)¶
Base Trainer to inherit functionality from for training Generative Adversarial Networks.
Parameters: - generator (nn.Module) – The generator network.
- discriminator (nn.Module) – The discriminator network.
- g_optimizer (torch.optim.Optimizer) – Generator Optimizer.
- d_optimizer (torch.optim.Optimizer) – Discriminator Optimizer.
- d_iter (int) – Number of discriminator steps per generator step.
Inherits from dlt.train.BaseTrainer.
Vanilla¶
class dlt.train.VanillaTrainer(model, criterion, optimizer)¶
Training of a network using a criterion/loss function.
Parameters: - model (nn.Module) – The network to train.
- criterion (callable) – The function to optimize.
- optimizer (torch.optim.Optimizer) – A torch Optimizer.
Each iteration returns the mini-batch and a tuple containing:
- The model prediction.
- A dictionary with the training_loss or validation_loss (along with the partial losses, if criterion returns a dictionary).
Example
>>> trainer = dlt.train.VanillaTrainer(my_model, nn.L1Loss(), my_optimizer)
>>> # Training mode
>>> trainer.train()
>>> for batch, (prediction, loss) in trainer(train_data_loader):
>>>     print(loss['training_loss'])
>>> # Validation mode
>>> trainer.eval()
>>> for batch, (prediction, loss) in trainer(valid_data_loader):
>>>     print(loss['validation_loss'])
Note
If the criterion returns a dict of (named) losses, they are added together for backpropagation. The total is returned along with all the partial losses.
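A minimal sketch of such a criterion; the loss names l1 and mse are made up for illustration:
>>> import torch.nn.functional as F
>>> # Hypothetical criterion returning named partial losses; the trainer
>>> # backpropagates their sum and reports each loss under its key.
>>> def my_criterion(prediction, target):
>>>     return {'l1': F.l1_loss(prediction, target),
>>>             'mse': F.mse_loss(prediction, target)}
>>> trainer = dlt.train.VanillaTrainer(my_model, my_criterion, my_optimizer)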
Vanilla GAN¶
class dlt.train.VanillaGANTrainer(generator, discriminator, g_optimizer, d_optimizer, d_iter=1)¶
Generative Adversarial Networks trainer.
Parameters: - generator (nn.Module) – The generator network.
- discriminator (nn.Module) – The discriminator network.
- g_optimizer (torch.optim.Optimizer) – Generator Optimizer.
- d_optimizer (torch.optim.Optimizer) – Discriminator Optimizer.
- d_iter (int, optional) – Number of discriminator steps per generator step (default 1).
Each iteration returns the mini-batch and a tuple containing:
- The generator prediction.
- A dictionary containing a d_loss (not when validating) and a g_loss dictionary (only if a generator step is performed):
- d_loss contains: d_loss, real_loss, and fake_loss.
- g_loss contains: g_loss.
Example
>>> trainer = dlt.train.VanillaGANTrainer(gen, disc, g_optim, d_optim)
>>> # Training mode
>>> trainer.train()
>>> for batch, (prediction, loss) in trainer(train_data_loader):
>>>     print(loss['d_loss']['d_loss'])
Warning
This trainer uses BCEWithLogitsLoss, which means that the discriminator must NOT have a sigmoid at the end.
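A compatible discriminator therefore ends in a raw logit output; a minimal sketch (the architecture is made up):
>>> import torch.nn as nn
>>> # The final layer outputs raw logits with no nn.Sigmoid(), since
>>> # BCEWithLogitsLoss applies the sigmoid internally.
>>> disc = nn.Sequential(nn.Linear(784, 256),
>>>                      nn.LeakyReLU(0.2),
>>>                      nn.Linear(256, 1))  # raw logit output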
WGAN-GP¶
class dlt.train.WGANGPTrainer(generator, discriminator, g_optimizer, d_optimizer, lambda_gp, d_iter=1)¶
Wasserstein GAN Trainer with gradient penalty.
Parameters: - generator (nn.Module) – The generator network.
- discriminator (nn.Module) – The discriminator network.
- g_optimizer (torch.optim.Optimizer) – Generator Optimizer.
- d_optimizer (torch.optim.Optimizer) – Discriminator Optimizer.
- lambda_gp (float) – Weight of gradient penalty.
- d_iter (int, optional) – Number of discriminator steps per generator step (default 1).
Each iteration returns the mini-batch and a tuple containing:
- The generator prediction.
- A dictionary containing a d_loss (not when validating) and a g_loss dictionary (only if a generator step is performed):
- d_loss contains: d_loss, w_loss, and gp.
- g_loss contains: g_loss.
Example
>>> trainer = dlt.train.WGANGPTrainer(gen, disc, g_optim, d_optim, lambda_gp)
>>> # Training mode
>>> trainer.train()
>>> for batch, (prediction, loss) in trainer(train_data_loader):
>>>     print(loss['d_loss']['w_loss'])
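For reference, the gp term corresponds to the standard WGAN-GP gradient penalty; a sketch of the usual formulation, not necessarily the trainer's exact internals:
>>> import torch
>>> # Penalize the deviation of the critic's gradient norm from 1 at
>>> # points interpolated between real and fake samples.
>>> def gradient_penalty(disc, real, fake, lambda_gp):
>>>     alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
>>>     interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
>>>     out = disc(interp)
>>>     grads = torch.autograd.grad(out, interp, grad_outputs=torch.ones_like(out),
>>>                                 create_graph=True)[0]
>>>     grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
>>>     return lambda_gp * ((grad_norm - 1) ** 2).mean()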
WGAN-CT¶
class dlt.train.WGANCTTrainer(generator, discriminator, g_optimizer, d_optimizer, lambda_gp, m_ct, lambda_ct, d_iter=1)¶
Wasserstein GAN Trainer with gradient penalty and correction term.
From Improving the Improved Training of Wasserstein GANs: A Consistency Term and Its Dual Effect.
Parameters: - generator (nn.Module) – The generator network.
- discriminator (nn.Module) – The discriminator network.
- g_optimizer (torch.optim.Optimizer) – Generator Optimizer.
- d_optimizer (torch.optim.Optimizer) – Discriminator Optimizer.
- lambda_gp (float) – Weight of gradient penalty.
- m_ct (float) – Constant bound for consistency term.
- lambda_ct (float) – Weight of consistency term.
- d_iter (int, optional) – Number of discriminator steps per generator step (default 1).
Each iteration returns the mini-batch and a tuple containing:
- The generator prediction.
- A dictionary containing a d_loss (not when validating) and a g_loss dictionary (only if a generator step is performed):
- d_loss contains: d_loss, w_loss, gp, and ct.
- g_loss contains: g_loss.
Warning
The discriminator forward function needs to accept an optional bool argument correction_term. When set to True, the forward function must add dropout noise to the model and return a tuple containing the second-to-last output of the discriminator along with the final output.
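A compatible forward could therefore look like this; a minimal sketch, with a made-up architecture and layer names:
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> class Discriminator(nn.Module):
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.body = nn.Sequential(nn.Linear(784, 256), nn.LeakyReLU(0.2))
>>>         self.penultimate = nn.Linear(256, 128)
>>>         self.final = nn.Linear(128, 1)
>>>     def forward(self, x, correction_term=False):
>>>         h = self.body(x)
>>>         if correction_term:
>>>             # Add dropout noise and return the second-to-last output
>>>             # together with the final output, as the trainer expects.
>>>             h = F.dropout(h, p=0.5, training=True)
>>>             h2 = F.leaky_relu(self.penultimate(h), 0.2)
>>>             return h2, self.final(h2)
>>>         return self.final(F.leaky_relu(self.penultimate(h), 0.2))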
Example
>>> trainer = dlt.train.WGANCTTrainer(gen, disc, g_optim, d_optim, lambda_gp, m_ct, lambda_ct)
>>> # Training mode
>>> trainer.train()
>>> for batch, (prediction, loss) in trainer(train_data_loader):
>>>     print(loss['d_loss']['w_loss'])
BEGAN¶
class dlt.train.BEGANTrainer(generator, discriminator, g_optimizer, d_optimizer, lambda_k, gamma, d_iter=1)¶
Boundary Equilibrium GAN trainer.
Parameters: - generator (nn.Module) – The generator network.
- discriminator (nn.Module) – The discriminator network.
- g_optimizer (torch.optim.Optimizer) – Generator Optimizer.
- d_optimizer (torch.optim.Optimizer) – Discriminator Optimizer.
- lambda_k (float) – Learning rate of the k parameter.
- gamma (float) – Diversity ratio.
- d_iter (int, optional) – Number of discriminator steps per generator step (default 1).
Each iteration returns the mini-batch and a tuple containing:
- The generator prediction.
- A dictionary containing a d_loss (not when validating) and a g_loss dictionary (only if a generator step is performed):
- d_loss contains: d_loss, real_loss, fake_loss, k, balance, and measure.
- g_loss contains: g_loss.
Example
>>> trainer = dlt.train.BEGANTrainer(gen, disc, g_optim, d_optim, lambda_k, gamma)
>>> # Training mode
>>> trainer.train()
>>> for batch, (prediction, loss) in trainer(train_data_loader):
>>>     print(loss['d_loss']['measure'])
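For context, k, balance, and measure follow the standard BEGAN update rule; a sketch of the usual formulation, not necessarily the trainer's exact internals:
>>> # Per discriminator step: balance drives the real/fake equilibrium,
>>> # k is updated with learning rate lambda_k, and measure is the
>>> # reported convergence metric.
>>> def began_update(real_loss, fake_loss, k, lambda_k, gamma):
>>>     balance = gamma * real_loss - fake_loss
>>>     k = min(max(k + lambda_k * balance, 0.0), 1.0)  # clamp k to [0, 1]
>>>     measure = real_loss + abs(balance)
>>>     d_loss = real_loss - k * fake_loss
>>>     return d_loss, k, balance, measure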
Fisher-GAN¶
class dlt.train.FisherGANTrainer(generator, discriminator, g_optimizer, d_optimizer, rho, d_iter=1)¶
Fisher GAN trainer.
Parameters: - generator (nn.Module) – The generator network.
- discriminator (nn.Module) – The discriminator network.
- g_optimizer (torch.optim.Optimizer) – Generator Optimizer.
- d_optimizer (torch.optim.Optimizer) – Discriminator Optimizer.
- rho (float) – Quadratic penalty weight.
- d_iter (int, optional) – Number of discriminator steps per generator step (default 1).
Each iteration returns the mini-batch and a tuple containing:
- The generator prediction.
- A dictionary containing a d_loss (not when validating) and a g_loss dictionary (only if a generator step is performed):
- d_loss contains: ipm_enum, ipm_denom, ipm_ratio, d_loss, constraint, epf, eqf, epf2, eqf2, and lagrange.
- g_loss contains: g_loss.
Example
>>> trainer = dlt.train.FisherGANTrainer(gen, disc, g_optim, d_optim, rho)
>>> # Training mode
>>> trainer.train()
>>> for batch, (prediction, loss) in trainer(train_data_loader):
>>>     print(loss['d_loss']['constraint'])
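For context, the reported quantities follow the Fisher IPM objective with an augmented Lagrangian; a sketch of the usual formulation, not necessarily the trainer's exact internals:
>>> import math
>>> # epf/eqf are critic means over real/fake data, epf2/eqf2 the
>>> # corresponding second moments; lagrange is the multiplier and rho
>>> # the quadratic penalty weight.
>>> def fisher_quantities(epf, eqf, epf2, eqf2, lagrange, rho):
>>>     ipm_enum = epf - eqf
>>>     ipm_denom = math.sqrt((epf2 + eqf2) / 2)
>>>     ipm_ratio = ipm_enum / ipm_denom
>>>     constraint = 1.0 - (epf2 + eqf2) / 2
>>>     d_loss = -(ipm_enum + lagrange * constraint - rho / 2 * constraint ** 2)
>>>     return ipm_enum, ipm_denom, ipm_ratio, constraint, d_loss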