mmedit.models.editors.pggan.pggan

Module Contents

Classes

ProgressiveGrowingGAN

Progressive Growing Unconditional GAN.

Attributes

ModelType

TrainInput

mmedit.models.editors.pggan.pggan.ModelType[source]
mmedit.models.editors.pggan.pggan.TrainInput[source]
class mmedit.models.editors.pggan.pggan.ProgressiveGrowingGAN(generator, discriminator, data_preprocessor, nkimgs_per_scale, noise_size=None, interp_real=None, transition_kimgs: int = 600, prev_stage: int = 0, ema_config: Optional[Dict] = None)[source]

Bases: mmedit.models.base_models.BaseGAN

Progressive Growing Unconditional GAN.

This model implements the progressive growing training schedule proposed in "Progressive Growing of GANs for Improved Quality, Stability, and Variation" (ICLR 2018).
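The schedule can be illustrated with a small, self-contained sketch that maps the number of thousand images shown so far to the current resolution and fade-in weight. This is not the library's code: the helper name `scale_and_transition_weight`, the example schedule values, and the rule that each new scale starts with a linear fade-in lasting `transition_kimgs` are assumptions made for illustration.

```python
from typing import Dict, Tuple


def scale_and_transition_weight(
    shown_kimgs: float,
    nkimgs_per_scale: Dict[str, int],
    transition_kimgs: int = 600,
) -> Tuple[int, float]:
    """Return (current scale, fade-in weight) for a progressive schedule.

    Assumption: every scale except the first begins with a linear fade-in
    of `transition_kimgs` thousand images, during which the weight grows
    from 0.0 to 1.0; afterwards the weight stays at 1.0.
    """
    scales = sorted(nkimgs_per_scale, key=int)
    start = 0.0  # thousand images shown before the current scale began
    for idx, scale in enumerate(scales):
        budget = nkimgs_per_scale[scale]
        is_last = idx == len(scales) - 1
        if shown_kimgs < start + budget or is_last:
            if idx == 0:
                # The lowest resolution has nothing to fade in from.
                return int(scale), 1.0
            weight = (shown_kimgs - start) / transition_kimgs
            return int(scale), min(1.0, max(0.0, weight))
        start += budget
    raise ValueError("nkimgs_per_scale must not be empty")


# Hypothetical schedule: train 4x4 for 600k images, then 8x8 and 16x16
# for 1200k images each.
schedule = {"4": 600, "8": 1200, "16": 1200}
halfway = scale_and_transition_weight(900, schedule)
```

With this hypothetical schedule, 900 thousand images fall 300 thousand into the 8x8 stage, so the new layers are blended in at half weight.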

We highly recommend using GrowScaleImgDataset to reduce the computational load of data pre-processing.

Notes for using PGGAN:

  1. In the official implementation, Tero uses gradient penalty with norm_mode="HWC".

  2. We do not implement minibatch_repeats, which is used in the official TensorFlow implementation.

Notes for resuming progressive growing GANs: Users should specify prev_stage in train_cfg. Otherwise, the model may reset the optimizer state, which degrades performance. For example, if your model is resumed from the 256 stage, you should set train_cfg=dict(prev_stage=256).
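A resume configuration might look like the sketch below. It is a config fragment under stated assumptions, not a verified config: the schedule values are made up, and the generator, discriminator, and data_preprocessor entries are omitted. The note above writes the setting as train_cfg=dict(prev_stage=256), while the constructor signature lists prev_stage as a direct argument, so this sketch places it at the top level of the model config; check which form your version expects.

```python
# Hypothetical resume config; key names mirror the constructor signature.
# generator, discriminator and data_preprocessor configs are omitted here.
model = dict(
    type="ProgressiveGrowingGAN",
    # Assumed schedule: thousand images to train at each resolution.
    nkimgs_per_scale={"4": 600, "8": 1200, "16": 1200, "32": 1200,
                      "64": 1200, "128": 1200, "256": 1200},
    transition_kimgs=600,
    # The checkpoint being resumed was saved during the 256 stage.
    prev_stage=256,
)
```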

Parameters
  • generator (dict) – Config for generator.

  • discriminator (dict) – Config for discriminator.

forward(inputs: mmedit.utils.typing.ForwardInputs, data_samples: Optional[list] = None, mode: Optional[str] = None) → mmedit.utils.typing.SampleList[source]

Sample images from noises by using the generator.

Parameters
  • inputs (ForwardInputs) – Dict containing the necessary information (e.g. noise, num_batches, mode) to generate images.

  • data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.

  • mode (Optional[str]) – mode is not used in ProgressiveGrowingGAN. Defaults to None.

Returns

A list of EditDataSample containing the generated results.

Return type

SampleList

train_discriminator(inputs: torch.Tensor, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor][source]

Train discriminator.

Parameters
  • inputs (Tensor) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Not used in generator’s training.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor, fake_data: torch.Tensor, real_data: torch.Tensor) → Tuple[torch.Tensor, dict][source]

Get the discriminator loss. PGGAN uses the WGAN-GP loss and a discriminator shift loss to train the discriminator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

  • fake_data (Tensor) – Generated images, used to calculate gradient penalty.

  • real_data (Tensor) – Real images, used to calculate gradient penalty.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]
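For orientation, the combination of WGAN-GP and a shift (drift) term can be written as below, following the form given in the WGAN-GP and PGGAN papers. Here x corresponds to real_data, x̃ to fake_data, D(x) to disc_pred_real, and D(x̃) to disc_pred_fake. The weights λ and ε_drift are not stated on this page, and whether the shift term in this implementation is applied to real predictions only or to both real and fake predictions is an assumption to verify against the source.

```latex
\mathcal{L}_D =
    \mathbb{E}_{\tilde{x}}\big[D(\tilde{x})\big]
  - \mathbb{E}_{x}\big[D(x)\big]
  + \lambda\,\mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big]
  + \epsilon_{\text{drift}}\,\mathbb{E}_{x}\big[D(x)^2\big],
\qquad
\hat{x} = \alpha x + (1 - \alpha)\,\tilde{x},\quad \alpha \sim U[0, 1]
```

With norm_mode="HWC", the gradient norm is taken per sample over all channel, height, and width entries.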

train_generator(inputs: torch.Tensor, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor][source]

Train generator.

Parameters
  • inputs (Tensor) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Not used in generator’s training.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

gen_loss(disc_pred_fake: torch.Tensor) → Tuple[torch.Tensor, dict][source]

Generator loss for PGGAN. PGGAN uses the WGAN loss to train the generator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • recon_imgs (Tensor) – Reconstructive images.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]
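In the standard WGAN formulation, the generator loss is the negated mean critic score on generated images, where D(G(z)) corresponds to disc_pred_fake. This is the textbook form; any additional weighting applied by this implementation is not documented on this page.

```latex
\mathcal{L}_G = -\,\mathbb{E}_{z}\big[D(G(z))\big]
```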

train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict)[source]

Train step function.

This function implements the standard training iteration for asynchronous adversarial training. Namely, in each iteration, we first update the discriminator and then compute the generator loss with the newly updated discriminator.

For distributed training, we use the reducer from DDP to synchronize the necessary parameters in the current computational graph.

Parameters
  • data (dict) – Input data from dataloader.

  • optim_wrapper (OptimWrapperDict) – Dict containing the optimizer wrappers for the generator and discriminator.

  • ddp_reducer (Reducer | None, optional) – Reducer from ddp. It is used to prepare for backward() in ddp. Defaults to None.

  • running_status (dict | None, optional) – Contains necessary basic information for training, e.g., iteration number. Defaults to None.

Returns

Contains ‘log_vars’, ‘num_samples’, and ‘results’.

Return type

dict
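The update order described for train_step (discriminator first, then the generator against the freshly updated discriminator) can be sketched with plain callables. This is a toy illustration, not the library's implementation: `alternating_step`, the two toy update functions, and the `state` dict are hypothetical stand-ins for train_discriminator, train_generator, and the model's parameters.

```python
from typing import Callable, Dict, List

LogVars = Dict[str, float]


def alternating_step(
    batch: List[float],
    update_discriminator: Callable[[List[float]], LogVars],
    update_generator: Callable[[List[float]], LogVars],
) -> LogVars:
    """Run one iteration: discriminator update, then generator update."""
    log_vars: LogVars = {}
    # 1. Update the discriminator first.
    log_vars.update(update_discriminator(batch))
    # 2. The generator step then sees the newly updated discriminator.
    log_vars.update(update_generator(batch))
    return log_vars


# Toy stand-ins: a counter plays the role of the discriminator's weights.
state = {"disc_version": 0}


def toy_update_discriminator(batch: List[float]) -> LogVars:
    state["disc_version"] += 1  # pretend an optimizer step happened
    return {"loss_disc": 0.0}


def toy_update_generator(batch: List[float]) -> LogVars:
    # Record which discriminator version the generator step observed.
    return {"loss_gen": 0.0,
            "disc_version_seen": float(state["disc_version"])}


logs = alternating_step([0.1, 0.2], toy_update_discriminator,
                        toy_update_generator)
```

The recorded disc_version_seen shows that the generator step ran after the discriminator's update within the same iteration.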
