
mmedit.models.editors.wgan_gp

Package Contents

Classes

WGANGPDiscriminator

Discriminator for WGANGP.

WGANGPGenerator

Generator for WGANGP.

WGANGP

Implementation of Improved Training of Wasserstein GANs.

class mmedit.models.editors.wgan_gp.WGANGPDiscriminator(in_channel, in_scale, conv_module_cfg=None)

Bases: torch.nn.Module

Discriminator for WGANGP.

The WGANGP discriminator follows training configuration (a) described in the PGGAN paper, Progressive Growing of GANs for Improved Quality, Stability, and Variation (https://research.nvidia.com/sites/default/files/pubs/2017-10_Progressive-Growing-of/karras2018iclr-paper.pdf). Its implementation details are:

  1. Adopt convolution architecture specified in appendix A.2;

  2. Add layer normalization to all conv3x3 and conv4x4 layers;

  3. Use LeakyReLU in the discriminator except for the final output layer;

  4. Initialize all weights using He’s initializer.
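The four rules above can be illustrated with a small PyTorch critic. This is a hypothetical sketch: TinyCritic and conv_ln_lrelu are illustrative names, not mmedit classes, the network is sized for 8×8 inputs, and layer normalization is approximated with nn.GroupNorm(1, C), which normalizes each sample over all of its channels and pixels. It does not reproduce the full appendix A.2 layout.

```python
import torch
import torch.nn as nn


def conv_ln_lrelu(in_ch, out_ch, kernel_size, padding, slope=0.2):
    """Hypothetical conv -> layer norm -> LeakyReLU unit."""
    conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
    # He initialization matched to the LeakyReLU negative slope
    nn.init.kaiming_normal_(conv.weight, a=slope, nonlinearity="leaky_relu")
    nn.init.zeros_(conv.bias)
    # GroupNorm with a single group normalizes each sample over (C, H, W),
    # i.e. layer normalization; no batch statistics are involved
    return nn.Sequential(conv, nn.GroupNorm(1, out_ch), nn.LeakyReLU(slope))


class TinyCritic(nn.Module):
    """Hypothetical 8x8 critic; not mmedit's WGANGPDiscriminator."""

    def __init__(self, in_channel=3, ch=16):
        super().__init__()
        self.features = nn.Sequential(
            conv_ln_lrelu(in_channel, ch, 3, 1),  # conv3x3 at 8x8
            conv_ln_lrelu(ch, ch, 3, 1),
            nn.AvgPool2d(2),                      # 8x8 -> 4x4
            conv_ln_lrelu(ch, ch, 3, 1),
            conv_ln_lrelu(ch, ch, 4, 0),          # conv4x4: 4x4 -> 1x1
        )
        # Final output layer: linear, no LeakyReLU, unbounded critic score
        self.out = nn.Linear(ch, 1)
        nn.init.kaiming_normal_(self.out.weight, nonlinearity="linear")

    def forward(self, x):
        return self.out(self.features(x).flatten(1))
```

Because layer normalization uses no batch statistics, each sample's score does not depend on the rest of the batch, which keeps a per-sample gradient penalty well defined.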

Parameters
  • in_channel (int) – The channel number of the input image.

  • in_scale (int) – The scale of the input image.

  • conv_module_cfg (dict, optional) – Config for the convolution module used in this discriminator. Defaults to None.

_default_channels_per_scale
_default_conv_module_cfg
_default_upsample_cfg
forward(x)

Forward function.

Parameters

x (torch.Tensor) – Fake or real image tensor.

Returns

Prediction (critic score) for how real the input image is.

Return type

torch.Tensor

class mmedit.models.editors.wgan_gp.WGANGPGenerator(noise_size, out_scale, conv_module_cfg=None, upsample_cfg=None)

Bases: torch.nn.Module

Generator for WGANGP.

The WGANGP generator follows training configuration (a) described in the PGGAN paper, Progressive Growing of GANs for Improved Quality, Stability, and Variation (https://research.nvidia.com/sites/default/files/pubs/2017-10_Progressive-Growing-of/karras2018iclr-paper.pdf). Its implementation details are:

  1. Adopt convolution architecture specified in appendix A.2;

  2. Use batchnorm in the generator except for the final output layer;

  3. Use ReLU in the generator except for the final output layer;

  4. Use Tanh in the last layer;

  5. Initialize all weights using He’s initializer.
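A matching hypothetical sketch of the generator rules, for 8×8 outputs. TinyGenerator and conv_bn_relu are illustrative names, not mmedit classes; the linear projection of the noise to a 4×4 map and the nearest-neighbour upsampling are assumptions made to keep the example short.

```python
import torch
import torch.nn as nn


def conv_bn_relu(in_ch, out_ch):
    """Hypothetical conv3x3 -> batchnorm -> ReLU unit for hidden layers."""
    conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
    nn.init.kaiming_normal_(conv.weight, nonlinearity="relu")  # He init
    nn.init.zeros_(conv.bias)
    return nn.Sequential(conv, nn.BatchNorm2d(out_ch), nn.ReLU())


class TinyGenerator(nn.Module):
    """Hypothetical noise -> 8x8 image generator; not mmedit's WGANGPGenerator."""

    def __init__(self, noise_size=32, ch=16, out_channel=3):
        super().__init__()
        self.ch = ch
        # Project the noise vector to a 4x4 feature map
        self.fc = nn.Linear(noise_size, ch * 4 * 4)
        nn.init.kaiming_normal_(self.fc.weight, nonlinearity="relu")
        self.head = nn.Sequential(nn.BatchNorm2d(ch), nn.ReLU())
        self.body = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),  # 4x4 -> 8x8
            conv_bn_relu(ch, ch),
        )
        # Final output layer: no batchnorm and no ReLU, only Tanh
        self.to_rgb = nn.Conv2d(ch, out_channel, 1)
        nn.init.kaiming_normal_(self.to_rgb.weight, nonlinearity="linear")

    def forward(self, noise):
        x = self.fc(noise).view(-1, self.ch, 4, 4)
        x = self.body(self.head(x))
        return torch.tanh(self.to_rgb(x))  # pixel values in [-1, 1]
```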

Parameters
  • noise_size (int) – Size of the input noise vector.

  • out_scale (int) – Output scale for the generated image.

  • conv_module_cfg (dict, optional) – Config for the convolution module used in this generator. Defaults to None.

  • upsample_cfg (dict, optional) – Config for the upsampling operation. Defaults to None.
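This page does not state how out_scale maps to network depth. Assuming, as in the progressive layout of the PGGAN paper, that the generator starts from a 4×4 feature map and doubles the resolution once per stage, the number of upsampling stages could be computed with the hypothetical helper below.

```python
import math


def num_upsample_stages(out_scale: int, base_scale: int = 4) -> int:
    """Hypothetical helper: doublings needed to go from base_scale to out_scale."""
    stages = math.log2(out_scale / base_scale)
    # Only scales of the form base_scale * 2**k (k >= 0) are reachable
    if out_scale < base_scale or not stages.is_integer():
        raise ValueError(f"out_scale must be base_scale * 2**k, got {out_scale}")
    return int(stages)
```

For example, under this assumption a 128×128 output needs five upsampling stages (4 → 8 → 16 → 32 → 64 → 128).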

_default_channels_per_scale
_default_conv_module_cfg
_default_upsample_cfg
forward(noise, num_batches=0, return_noise=False)

Forward function.

Parameters
  • noise (torch.Tensor | callable | None) – A batch of noise given directly as a torch.Tensor, or a callable that samples a batch of noise. If None, the default noise sampler is used.

  • num_batches (int, optional) – The batch size. Defaults to 0.

  • return_noise (bool, optional) – If True, noise_batch is returned together with fake_img in a dict. Defaults to False.

Returns

If return_noise is False, only the output image is returned. Otherwise, a dict containing fake_img and noise_batch is returned.

Return type

torch.Tensor | dict
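The three accepted forms of noise, together with num_batches and return_noise, can be sketched as follows. sample_noise and generate are hypothetical helpers; the standard-normal default sampler and a callable that takes a shape tuple are assumptions, not documented mmedit behavior.

```python
from typing import Callable, Dict, Union

import torch


def sample_noise(noise: Union[torch.Tensor, Callable, None],
                 num_batches: int, noise_size: int) -> torch.Tensor:
    """Hypothetical helper mirroring the documented `noise` semantics."""
    if isinstance(noise, torch.Tensor):
        # A batch of noise was given directly; use it as-is
        if noise.dim() != 2 or noise.size(1) != noise_size:
            raise ValueError("expected noise of shape (N, noise_size)")
        return noise
    if callable(noise):
        # A sampler was given (assumed to accept a shape tuple)
        return noise((num_batches, noise_size))
    # None: fall back to an assumed standard-normal default sampler
    return torch.randn((num_batches, noise_size))


def generate(generator, noise=None, num_batches=0, return_noise=False,
             noise_size=32) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    noise_batch = sample_noise(noise, num_batches, noise_size)
    fake_img = generator(noise_batch)
    if return_noise:
        return dict(fake_img=fake_img, noise_batch=noise_batch)
    return fake_img
```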

class mmedit.models.editors.wgan_gp.WGANGP(*args, **kwargs)

Bases: mmedit.models.base_models.BaseGAN

Implementation of Improved Training of Wasserstein GANs.

Paper link: https://arxiv.org/pdf/1704.00028

The detailed architecture is described in WGANGPGenerator and WGANGPDiscriminator.
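For reference, the critic and generator objectives from the linked paper are given below, where $\mathbb{P}_r$ and $\mathbb{P}_g$ are the real and generated data distributions and $\lambda$ is the gradient-penalty weight (the paper uses $\lambda = 10$; this page does not state the value mmedit uses):

```latex
\mathcal{L}_D = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big]
  - \mathbb{E}_{x \sim \mathbb{P}_r}\big[D(x)\big]
  + \lambda \, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big],
\qquad \hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}, \quad \epsilon \sim U[0, 1]

\mathcal{L}_G = -\,\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big]
```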

disc_loss(real_data: torch.Tensor, fake_data: torch.Tensor, disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor) → Tuple

Get the discriminator loss. WGAN-GP uses the WGAN loss together with a gradient penalty to train the discriminator.

Parameters
  • real_data (Tensor) – Real input data.

  • fake_data (Tensor) – Fake input data.

  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]
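The discriminator loss can be sketched in PyTorch as the WGAN term plus a weighted gradient penalty on random interpolates between real and fake samples, returning a loss and a dict of log variables like the documented return type. The function names, the log keys, passing the critic explicitly, and the default weight of 10 (taken from the paper) are assumptions, not mmedit's code.

```python
from typing import Dict, Tuple

import torch
import torch.nn as nn


def gradient_penalty(critic: nn.Module, real: torch.Tensor,
                     fake: torch.Tensor) -> torch.Tensor:
    """Mean of (||grad_x D(x_hat)||_2 - 1)^2 over random interpolates."""
    # One interpolation coefficient per sample, broadcast over C, H, W
    eps = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real.detach() + (1 - eps) * fake.detach()).requires_grad_(True)
    # create_graph=True so the penalty itself can be backpropagated
    grads, = torch.autograd.grad(critic(x_hat).sum(), x_hat, create_graph=True)
    return ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()


def disc_loss_sketch(critic, real, fake, pred_fake, pred_real,
                     gp_weight: float = 10.0) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Hypothetical WGAN-GP critic loss returning (loss, log_vars)."""
    loss_wgan = pred_fake.mean() - pred_real.mean()
    loss_gp = gradient_penalty(critic, real, fake)
    loss = loss_wgan + gp_weight * loss_gp
    log_vars = dict(loss=loss.item(), loss_wgan=loss_wgan.item(),
                    loss_gp=loss_gp.item())
    return loss, log_vars
```

For a purely linear critic the input gradient equals the weight vector, so the penalty is zero exactly when that weight has unit norm, which makes the sketch easy to sanity-check.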

gen_loss(disc_pred_fake: torch.Tensor) → Tuple

Get the generator loss. WGAN-GP uses the WGAN loss to train the generator.

Parameters

disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]
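The WGAN generator loss negates the critic's mean score on fake images, so minimizing it pushes the generator toward samples the critic scores as more real. A hypothetical sketch with the same (loss, log dict) return shape; the log key name is an assumption:

```python
import torch


def gen_loss_sketch(disc_pred_fake: torch.Tensor):
    """Hypothetical WGAN generator loss returning (loss, log_vars)."""
    loss = -disc_pred_fake.mean()
    return loss, dict(loss_gen=loss.item())
```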

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]

Train discriminator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]
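One discriminator update can be sketched with a plain torch.optim optimizer. In mmedit the update goes through mmengine's OptimWrapper, whose update_params method bundles the backward pass, optimizer step, and gradient reset; the function below, its noise size, and its log keys are hypothetical.

```python
import torch
import torch.nn as nn


def train_discriminator_step(generator, critic, real, optimizer,
                             noise_size=8, gp_weight=10.0):
    """One hypothetical critic update; only the critic's parameters change."""
    with torch.no_grad():
        # Fake images are produced without tracking gradients to the generator
        fake = generator(torch.randn(real.size(0), noise_size))
    loss_wgan = critic(fake).mean() - critic(real).mean()
    # Gradient penalty on random interpolates between real and fake samples
    eps = torch.rand(real.size(0), *([1] * (real.dim() - 1)))
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    grads, = torch.autograd.grad(critic(x_hat).sum(), x_hat, create_graph=True)
    loss_gp = ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()
    loss = loss_wgan + gp_weight * loss_gp
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # optimizer holds only the critic's parameters
    return dict(loss=loss.item(), loss_wgan=loss_wgan.item(),
                loss_gp=loss_gp.item())
```

Generating the fake batch under torch.no_grad() keeps the generator out of the computation graph, so this step cannot modify the generator.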

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]

Train generator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Not used in the generator's training.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]
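A matching hypothetical generator update, again with a plain torch.optim optimizer. The critic's parameters are frozen during the backward pass so that no gradients accumulate in them. The paper alternates several critic updates (n_critic = 5 by default) with each generator update; this page does not state how mmedit schedules the two.

```python
import torch
import torch.nn as nn


def train_generator_step(generator, critic, optimizer, batch_size, noise_size=8):
    """One hypothetical generator update; the critic's weights are not touched."""
    critic.requires_grad_(False)  # freeze the critic for this update
    try:
        fake = generator(torch.randn(batch_size, noise_size))
        # WGAN generator loss: raise the critic's score on fake images
        loss = -critic(fake).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # optimizer holds only the generator's parameters
    finally:
        critic.requires_grad_(True)  # unfreeze for the next critic update
    return dict(loss_gen=loss.item())
```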
