dlt.train

Base Classes

class dlt.train.BaseTrainer

Generic base trainer from which the trainers below inherit common functionality.

__call__(loader)

Performs an epoch of training or validation.

Parameters: loader (iterable) – The data loader.
cpu(device=0)

Sets the trainer to CPU mode.

cuda(device=0)

Sets the trainer to GPU mode.

When GPU mode is set, data retrieved from the loader is cast to the GPU before every iteration.
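
For example (assuming an already constructed trainer instance):

>>> trainer.cuda()   # move to GPU 0; loader output is cast to the GPU each iteration
>>> trainer.cuda(1)  # or target a different device
>>> trainer.cpu()    # back to CPU mode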

eval()

Sets the trainer and models to inference mode.

iterate(loader)

Performs an epoch of training or validation.

Parameters: loader (iterable) – The data loader.
load_state_dict(state_dict)

Loads the trainer's state.

Parameters: state_dict (dict) – Trainer state. Should be an object returned from a call to state_dict().
loss_names(training=None)

Returns the name(s)/key(s) of the training or validation loss(es).

Parameters: training (bool, optional) – If provided, the training losses are returned if True and the validation losses if False. If not provided, the losses for the current mode are returned.
loss_names_training()

Returns the name(s)/key(s) of the training loss(es).

loss_names_validation()

Returns the name(s)/key(s) of the validation loss(es).
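
These can be used to write logging code that works in either mode, e.g. (a sketch assuming loss_names() returns a list of keys of the loss dictionary, as with VanillaTrainer's training_loss and validation_loss):

>>> for batch, (prediction, loss) in trainer(loader):
>>>     for name in trainer.loss_names():
>>>         print(name, loss[name])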

state_dict()

Returns the state of the trainer as a dict.

It contains an entry for every variable in self.__dict__ which is not one of the models or optimizers.
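
Since the models and optimizers are excluded, they should be checkpointed separately, e.g. (a minimal sketch using torch.save; file names are illustrative):

>>> torch.save(trainer.state_dict(), 'trainer.pth')
>>> torch.save(my_model.state_dict(), 'model.pth')
>>> torch.save(my_optimizer.state_dict(), 'optimizer.pth')
>>> # ... and to restore later:
>>> trainer.load_state_dict(torch.load('trainer.pth'))
>>> my_model.load_state_dict(torch.load('model.pth'))
>>> my_optimizer.load_state_dict(torch.load('optimizer.pth'))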

train()

Sets the trainer and models to training mode.

class dlt.train.GANBaseTrainer(generator, discriminator, g_optimizer, d_optimizer, d_iter)

Base trainer for Generative Adversarial Networks, from which the GAN trainers below inherit functionality.

Parameters:
  • generator (nn.Module) – The generator network.
  • discriminator (nn.Module) – The discriminator network.
  • g_optimizer (torch.optim.Optimizer) – Generator Optimizer.
  • d_optimizer (torch.optim.Optimizer) – Discriminator Optimizer.
  • d_iter (int) – Number of discriminator steps per generator step.

Inherits from dlt.train.BaseTrainer

Vanilla

class dlt.train.VanillaTrainer(model, criterion, optimizer)

Trainer for a network using a criterion/loss function.

Parameters:
  • model (nn.Module) – The network to train.
  • criterion (callable) – The loss function to minimize.
  • optimizer (torch.optim.Optimizer) – A torch Optimizer.

Each iteration returns the mini-batch and a tuple containing:

  • The model prediction.
  • A dictionary with the training_loss or validation_loss (along with the partial losses, if the criterion returns a dictionary).

Example

>>> trainer = dlt.train.VanillaTrainer(my_model, nn.L1Loss(), my_optimizer)
>>> # Training mode
>>> trainer.train()
>>> for batch, (prediction, loss) in trainer(train_data_loader):
>>>     print(loss['training_loss'])
>>> # Validation mode
>>> trainer.eval()
>>> for batch, (prediction, loss) in trainer(valid_data_loader):
>>>     print(loss['validation_loss'])

Note

If the criterion returns a dict of (named) losses, they are summed for backpropagation. The total is returned along with all the partial losses.
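
For instance, a criterion returning named partial losses could look like the following sketch (it assumes the criterion is called as criterion(prediction, target)):

>>> def my_criterion(prediction, target):
>>>     return {'l1': nn.functional.l1_loss(prediction, target),
>>>             'mse': nn.functional.mse_loss(prediction, target)}
>>> trainer = dlt.train.VanillaTrainer(my_model, my_criterion, my_optimizer)
>>> trainer.train()
>>> for batch, (prediction, loss) in trainer(train_data_loader):
>>>     print(loss['training_loss'], loss['l1'], loss['mse'])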

Vanilla GAN

class dlt.train.VanillaGANTrainer(generator, discriminator, g_optimizer, d_optimizer, d_iter=1)

Generative Adversarial Networks trainer.

Parameters:
  • generator (nn.Module) – The generator network.
  • discriminator (nn.Module) – The discriminator network.
  • g_optimizer (torch.optim.Optimizer) – Generator Optimizer.
  • d_optimizer (torch.optim.Optimizer) – Discriminator Optimizer.
  • d_iter (int, optional) – Number of discriminator steps per generator step (default 1).

Each iteration returns the mini-batch and a tuple containing:

  • The generator prediction.

  • A dictionary containing a d_loss dictionary (omitted when validating) and a g_loss dictionary (included only when a generator step is performed):

    • d_loss contains: d_loss, real_loss, and fake_loss.
    • g_loss contains: g_loss.

Example

>>> trainer = dlt.train.VanillaGANTrainer(gen, disc, g_optim, d_optim)
>>> # Training mode
>>> trainer.train()
>>> for batch, (prediction, loss) in trainer(train_data_loader):
>>>     print(loss['d_loss']['d_loss'])

Warning

This trainer uses BCEWithLogitsLoss, which means that the discriminator must NOT have a sigmoid at the end.
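
For example, a compatible discriminator can simply end in a linear layer that outputs raw logits (a minimal sketch; the architecture is illustrative):

>>> class Discriminator(nn.Module):
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.features = nn.Sequential(nn.Linear(784, 256), nn.LeakyReLU(0.2))
>>>         self.classifier = nn.Linear(256, 1)  # raw logits; no sigmoid
>>>     def forward(self, x):
>>>         return self.classifier(self.features(x))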

WGAN-GP

class dlt.train.WGANGPTrainer(generator, discriminator, g_optimizer, d_optimizer, lambda_gp, d_iter=1)

Wasserstein GAN Trainer with gradient penalty.

Parameters:
  • generator (nn.Module) – The generator network.
  • discriminator (nn.Module) – The discriminator network.
  • g_optimizer (torch.optim.Optimizer) – Generator Optimizer.
  • d_optimizer (torch.optim.Optimizer) – Discriminator Optimizer.
  • lambda_gp (float) – Weight of gradient penalty.
  • d_iter (int, optional) – Number of discriminator steps per generator step (default 1).

Each iteration returns the mini-batch and a tuple containing:

  • The generator prediction.

  • A dictionary containing a d_loss dictionary (omitted when validating) and a g_loss dictionary (included only when a generator step is performed):

    • d_loss contains: d_loss, w_loss, and gp.
    • g_loss contains: g_loss.

Example

>>> trainer = dlt.train.WGANGPTrainer(gen, disc, g_optim, d_optim, lambda_gp)
>>> # Training mode
>>> trainer.train()
>>> for batch, (prediction, loss) in trainer(train_data_loader):
>>>     print(loss['d_loss']['w_loss'])
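
For reference, the gradient penalty weighted by lambda_gp is typically computed as in the following sketch (illustrative only; it may differ from the trainer's internal implementation and assumes 4D image batches):

>>> def gradient_penalty(disc, real, fake):
>>>     alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
>>>     interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
>>>     grads, = torch.autograd.grad(disc(interp).sum(), interp, create_graph=True)
>>>     grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
>>>     return ((grad_norm - 1) ** 2).mean()  # scaled by lambda_gp in the loss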

WGAN-CT

class dlt.train.WGANCTTrainer(generator, discriminator, g_optimizer, d_optimizer, lambda_gp, m_ct, lambda_ct, d_iter=1)

Wasserstein GAN Trainer with gradient penalty and consistency term.

From Improving the Improved Training of Wasserstein GANs: A Consistency Term and Its Dual Effect.

https://openreview.net/forum?id=SJx9GQb0-

Parameters:
  • generator (nn.Module) – The generator network.
  • discriminator (nn.Module) – The discriminator network.
  • g_optimizer (torch.optim.Optimizer) – Generator Optimizer.
  • d_optimizer (torch.optim.Optimizer) – Discriminator Optimizer.
  • lambda_gp (float) – Weight of gradient penalty.
  • m_ct (float) – Constant bound for consistency term.
  • lambda_ct (float) – Weight of consistency term.
  • d_iter (int, optional) – Number of discriminator steps per generator step (default 1).

Each iteration returns the mini-batch and a tuple containing:

  • The generator prediction.

  • A dictionary containing a d_loss dictionary (omitted when validating) and a g_loss dictionary (included only when a generator step is performed):

    • d_loss contains: d_loss, w_loss, gp, and ct.
    • g_loss contains: g_loss.

Warning

The discriminator's forward function must accept an optional bool argument correction_term. When it is True, the forward pass must apply dropout noise to the model and return a tuple containing the second-to-last output of the discriminator along with the final output.
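
A discriminator satisfying this interface could look like the following sketch (the architecture is hypothetical; only the forward signature matters):

>>> class Discriminator(nn.Module):
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.features = nn.Sequential(nn.Linear(784, 256), nn.LeakyReLU(0.2))
>>>         self.dropout = nn.Dropout(0.5)
>>>         self.classifier = nn.Linear(256, 1)
>>>     def forward(self, x, correction_term=False):
>>>         if correction_term:
>>>             h = self.dropout(self.features(x))  # dropout noise for the consistency term
>>>             return h, self.classifier(h)        # (second-to-last output, final output)
>>>         return self.classifier(self.features(x))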

Example

>>> trainer = dlt.train.WGANCTTrainer(gen, disc, g_optim, d_optim, lambda_gp, m_ct, lambda_ct)
>>> # Training mode
>>> trainer.train()
>>> for batch, (prediction, loss) in trainer(train_data_loader):
>>>     print(loss['d_loss']['w_loss'])

BEGAN

class dlt.train.BEGANTrainer(generator, discriminator, g_optimizer, d_optimizer, lambda_k, gamma, d_iter=1)

Boundary Equilibrium GAN trainer.

Parameters:
  • generator (nn.Module) – The generator network.
  • discriminator (nn.Module) – The discriminator network.
  • g_optimizer (torch.optim.Optimizer) – Generator Optimizer.
  • d_optimizer (torch.optim.Optimizer) – Discriminator Optimizer.
  • lambda_k (float) – Learning rate of the k parameter.
  • gamma (float) – Diversity ratio.
  • d_iter (int, optional) – Number of discriminator steps per generator step (default 1).

Each iteration returns the mini-batch and a tuple containing:

  • The generator prediction.

  • A dictionary containing a d_loss dictionary (omitted when validating) and a g_loss dictionary (included only when a generator step is performed):

    • d_loss contains: d_loss, real_loss, fake_loss, k, balance, and measure.
    • g_loss contains: g_loss.

Example

>>> trainer = dlt.train.BEGANTrainer(gen, disc, g_optim, d_optim, lambda_k, gamma)
>>> # Training mode
>>> trainer.train()
>>> for batch, (prediction, loss) in trainer(train_data_loader):
>>>     print(loss['d_loss']['measure'])
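
For context, the reported quantities follow the update rules from the BEGAN paper (a sketch; real_loss and fake_loss are assumed to be the autoencoder reconstruction losses of real and generated samples):

>>> # d_loss  = real_loss - k * fake_loss
>>> # k       is updated as k + lambda_k * balance, clamped to [0, 1]
>>> # balance = gamma * real_loss - fake_loss
>>> # measure = real_loss + abs(balance)   # global convergence measure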

Fisher-GAN

class dlt.train.FisherGANTrainer(generator, discriminator, g_optimizer, d_optimizer, rho, d_iter=1)

Fisher GAN trainer.

Parameters:
  • generator (nn.Module) – The generator network.
  • discriminator (nn.Module) – The discriminator network.
  • g_optimizer (torch.optim.Optimizer) – Generator Optimizer.
  • d_optimizer (torch.optim.Optimizer) – Discriminator Optimizer.
  • rho (float) – Quadratic penalty weight.
  • d_iter (int, optional) – Number of discriminator steps per generator step (default 1).

Each iteration returns the mini-batch and a tuple containing:

  • The generator prediction.

  • A dictionary containing a d_loss dictionary (omitted when validating) and a g_loss dictionary (included only when a generator step is performed):

    • d_loss contains: ipm_enum, ipm_denom, ipm_ratio, d_loss, constraint, epf, eqf, epf2, eqf2, and lagrange.
    • g_loss contains: g_loss.

Example

>>> trainer = dlt.train.FisherGANTrainer(gen, disc, g_optim, d_optim, rho)
>>> # Training mode
>>> trainer.train()
>>> for batch, (prediction, loss) in trainer(train_data_loader):
>>>     print(loss['d_loss']['constraint'])
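
For context, the reported quantities correspond to the Fisher IPM objective with an augmented Lagrangian constraint (a sketch based on the Fisher GAN paper; the naming is inferred from the keys above):

>>> # epf, eqf   : first moments  E[f(real)], E[f(fake)]
>>> # epf2, eqf2 : second moments E[f(real)**2], E[f(fake)**2]
>>> # ipm_enum   = epf - eqf
>>> # ipm_denom  = sqrt(0.5 * (epf2 + eqf2))
>>> # ipm_ratio  = ipm_enum / ipm_denom
>>> # constraint : the second-moment constraint, enforced via the lagrange
>>> #              multiplier and the quadratic penalty weighted by rho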