
mmedit.models.editors.pggan.pggan

Module Contents

Classes

ProgressiveGrowingGAN

Progressive Growing Unconditional GAN.

Attributes

ModelType

TrainInput

mmedit.models.editors.pggan.pggan.ModelType[source]
mmedit.models.editors.pggan.pggan.TrainInput[source]
class mmedit.models.editors.pggan.pggan.ProgressiveGrowingGAN(generator, discriminator, data_preprocessor, nkimgs_per_scale, noise_size=None, interp_real=None, transition_kimgs: int = 600, prev_stage: int = 0, ema_config: Optional[Dict] = None)[source]

Bases: mmedit.models.base_models.BaseGAN

Progressive Growing Unconditional GAN.

This GAN model implements the progressive growing training schedule proposed in Progressive Growing of GANs for Improved Quality, Stability, and Variation (ICLR 2018).

We highly recommend using GrowScaleImgDataset to reduce the computational load of data pre-processing.
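
The progressive growing schedule can be illustrated with a small, library-independent sketch. It assumes that nkimgs_per_scale maps each resolution to the number of thousands of images shown at that stage, and that the first transition_kimgs of every stage after the lowest one are spent fading in the new layers. The helper name current_stage and the linear fade are illustrative assumptions, not the mmedit implementation.

```python
from typing import Dict, Tuple


def current_stage(shown_kimgs: float,
                  nkimgs_per_scale: Dict[int, float],
                  transition_kimgs: float = 600) -> Tuple[int, float]:
    """Return (resolution, transition_weight) for the images shown so far.

    Hypothetical helper: transition_weight rises linearly from 0 to 1 over
    the first `transition_kimgs` of each stage except the lowest one.
    """
    stage_start = 0.0
    scales = sorted(nkimgs_per_scale)
    for idx, scale in enumerate(scales):
        stage_end = stage_start + nkimgs_per_scale[scale]
        is_last = idx == len(scales) - 1
        if shown_kimgs < stage_end or is_last:
            if idx == 0 or transition_kimgs <= 0:
                return scale, 1.0  # nothing to fade in
            progress = (shown_kimgs - stage_start) / transition_kimgs
            return scale, min(max(progress, 0.0), 1.0)
        stage_start = stage_end
    raise ValueError('nkimgs_per_scale must not be empty')


schedule = {4: 600, 8: 1200, 16: 1200}
print(current_stage(100, schedule))   # lowest scale, fully active
print(current_stage(900, schedule))   # halfway through fading in the 8x8 layers
print(current_stage(1500, schedule))  # 8x8 stage after the fade-in has finished
```

Training past the end of the schedule simply stays at the highest scale with a transition weight of 1.0 in this sketch.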

Notes for using PGGAN:

  1. In the official implementation, Tero Karras uses a gradient penalty with norm_mode="HWC".

  2. We do not implement minibatch_repeats, which is used in the official TensorFlow implementation.

Notes for resuming progressive growing GANs: Users should specify prev_stage in train_cfg. Otherwise, the model may reset the optimizer state, which degrades performance. For example, if your model is resumed from the 256 stage, set train_cfg=dict(prev_stage=256).
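
A minimal config sketch for the resume case above, assuming training stopped in the 256 stage. Only prev_stage=256 comes from the note. The note places prev_stage in train_cfg, while the constructor signature above accepts it directly, so which form applies depends on the installed version; both variants are shown here as assumptions.

```python
# Variant following the note above: prev_stage passed through train_cfg.
train_cfg = dict(prev_stage=256)

# Variant following the constructor signature: prev_stage as a model argument.
# Other required arguments (generator, discriminator, ...) are omitted here.
model = dict(type='ProgressiveGrowingGAN', prev_stage=256)
```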

Parameters
  • generator (dict) – Config for generator.

  • discriminator (dict) – Config for discriminator.

forward(inputs: mmedit.utils.typing.ForwardInputs, data_samples: Optional[list] = None, mode: Optional[str] = None) → mmedit.utils.typing.SampleList[source]

Sample images from noises by using the generator.

Parameters
  • inputs (ForwardInputs) – Dict containing the information needed to generate images (e.g. noise, num_batches, mode).

  • data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.

  • mode (Optional[str]) – mode is not used in ProgressiveGrowingGAN. Defaults to None.

Returns

A list of EditDataSample containing the generated results.

Return type

SampleList

train_discriminator(inputs: torch.Tensor, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor][source]

Train discriminator.

Parameters
  • inputs (Tensor) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Not used in the generator’s training.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor, fake_data: torch.Tensor, real_data: torch.Tensor) → Tuple[torch.Tensor, dict][source]

Get the discriminator loss. PGGAN uses the WGAN-GP loss and a discriminator shift loss to train the discriminator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

  • fake_data (Tensor) – Generated images, used to calculate gradient penalty.

  • real_data (Tensor) – Real images, used to calculate gradient penalty.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]
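
For reference, the discriminator objective described in the PGGAN paper combines the WGAN-GP loss with a small drift ("shift") term. The following is a sketch in the paper's notation. The weights and the exact set of predictions covered by the shift term in this implementation are not stated on this page, so treat them as assumptions.

```latex
\mathcal{L}_D = \mathbb{E}_{z}\big[D(G(z))\big] - \mathbb{E}_{x}\big[D(x)\big]
  + \lambda\, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big]
  + \epsilon_{\text{drift}}\, \mathbb{E}_{x}\big[D(x)^2\big],
\qquad \hat{x} = \alpha x + (1 - \alpha)\, G(z), \quad \alpha \sim U[0, 1]
```

Here D(G(z)) corresponds to disc_pred_fake and D(x) to disc_pred_real, while fake_data and real_data supply the interpolated samples for the gradient penalty. The paper uses λ = 10 and ε_drift = 0.001.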

train_generator(inputs: torch.Tensor, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor][source]

Train generator.

Parameters
  • inputs (Tensor) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Not used in the generator’s training.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

gen_loss(disc_pred_fake: torch.Tensor) → Tuple[torch.Tensor, dict][source]

Generator loss for PGGAN. PGGAN uses the WGAN loss to train the generator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • recon_imgs (Tensor) – Reconstructive images.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]
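
The WGAN generator loss is the negated mean of the discriminator's scores on generated images. The sketch below shows that arithmetic on plain Python floats instead of tensors; the function name wgan_gen_loss and the log key loss_gen are hypothetical and not taken from mmedit.

```python
from typing import Dict, List, Tuple


def wgan_gen_loss(disc_pred_fake: List[float]) -> Tuple[float, Dict[str, float]]:
    """WGAN generator loss: the negated mean critic score on fake images."""
    loss = -sum(disc_pred_fake) / len(disc_pred_fake)
    log_vars = {'loss_gen': loss}  # hypothetical log key
    return loss, log_vars


loss, log_vars = wgan_gen_loss([0.5, -1.5, 2.5])
print(loss, log_vars)
```

A higher average score on fake images lowers the loss, which is what pushes the generator toward samples the discriminator rates as real.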

train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict)[source]

Train step function.

This function implements the standard training iteration for asynchronous adversarial training. In each iteration, we first update the discriminator and then compute the generator loss with the newly updated discriminator.

For distributed training, we use the reducer from DDP to synchronize the necessary parameters in the current computational graph.

Parameters
  • data_batch (dict) – Input data from dataloader.

  • optimizer (dict) – Dict containing the optimizers for the generator and discriminator.

  • ddp_reducer (Reducer | None, optional) – Reducer from DDP. It is used to prepare for backward() in DDP. Defaults to None.

  • running_status (dict | None, optional) – Contains necessary basic information for training, e.g., iteration number. Defaults to None.

Returns

A dict containing ‘log_vars’, ‘num_samples’, and ‘results’.

Return type

dict
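
The update order described for train_step can be sketched without any library code. The callables below are hypothetical stand-ins for train_discriminator and train_generator; the sketch only shows that the discriminator is updated first and that the log variables from both updates are merged.

```python
from typing import Callable, Dict, List


def train_step_sketch(train_discriminator: Callable[[], Dict[str, float]],
                      train_generator: Callable[[], Dict[str, float]]
                      ) -> Dict[str, float]:
    """Run one iteration: discriminator first, then generator."""
    log_vars = dict(train_discriminator())  # updates the discriminator
    log_vars.update(train_generator())      # generator loss sees the updated D
    return log_vars


calls: List[str] = []


def fake_disc_update() -> Dict[str, float]:
    calls.append('discriminator')
    return {'loss_disc': 0.0}


def fake_gen_update() -> Dict[str, float]:
    calls.append('generator')
    return {'loss_gen': 0.0}


log_vars = train_step_sketch(fake_disc_update, fake_gen_update)
print(calls)
```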
