
mmedit.models.editors.mspie.mspie_stylegan2

Module Contents

Classes

MSPIEStyleGAN2

MS-PIE StyleGAN2.

Attributes

ModelType

TrainInput

mmedit.models.editors.mspie.mspie_stylegan2.ModelType
mmedit.models.editors.mspie.mspie_stylegan2.TrainInput
class mmedit.models.editors.mspie.mspie_stylegan2.MSPIEStyleGAN2(*args, train_settings=dict(), **kwargs)

Bases: mmedit.models.editors.stylegan2.StyleGAN2

MS-PIE StyleGAN2.

In this GAN, we adopt the MS-PIE training schedule so that images at multiple scales can be generated with a single generator. Details can be found in Positional Encoding as Spatial Inductive Bias in GANs, CVPR 2021.

Parameters

train_settings (dict) – Config for training settings. Defaults to dict().
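The multi-scale schedule described above can be pictured as sampling one training resolution per iteration from a set of candidate scales. The following is a minimal sketch, not this class's actual code: the setting names `num_upblocks`, `multi_input_scales` and `multi_scale_probability`, the assumed 4x4 base input map, and the size formula are all hypothetical and chosen for illustration.

```python
import random
from typing import Optional


def choose_train_size(train_settings: dict, rng: Optional[random.Random] = None) -> int:
    """Sample the square image size used for one training iteration.

    All keys read here are hypothetical and chosen for illustration:
      num_upblocks: number of 2x upsampling blocks after the input feature map
      multi_input_scales: extra sizes added to an assumed 4x4 input map
      multi_scale_probability: sampling weights, one per entry of multi_input_scales
    """
    if rng is None:
        rng = random.Random()
    num_upblocks = train_settings.get("num_upblocks", 6)
    scales = train_settings.get("multi_input_scales", [0])
    weights = train_settings.get("multi_scale_probability")  # None means uniform
    # Draw one extra input scale for this iteration.
    chosen = rng.choices(scales, weights=weights, k=1)[0]
    # Each upsampling block doubles the spatial size of the input map.
    return (4 + chosen) * 2 ** num_upblocks


settings = dict(
    num_upblocks=6,
    multi_input_scales=[0, 2, 4],
    multi_scale_probability=[0.5, 0.25, 0.25],
)
size = choose_train_size(settings, random.Random(0))
# size is one of 256, 384 or 512
```

In a real training loop the sampled size would have to be applied to both the generated and the real images so that the discriminator compares images at the same resolution, and in distributed training every process would need to agree on the sampled scale; both steps are omitted from this sketch.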

train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict) → Dict[str, torch.Tensor]

Train the GAN model. In GAN training, the generator and the discriminator are updated alternately. In MMGeneration’s design, self.train_step is called with the data input. Therefore, the discriminator, whose update relies on real data, is always updated, and whether the generator also needs to be updated is then decided based on the current iteration number. More details on when the generator is updated can be found in should_gen_update().

Parameters
  • data (dict) – Data sampled from dataloader.

  • optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance containing the OptimWrappers of the generator and the discriminator.

Returns

A dict of tensors for logging.

Return type

Dict[str, torch.Tensor]
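The update order described for train_step can be sketched as a step function that always updates the discriminator and updates the generator only on some iterations. This is a minimal sketch under stated assumptions: the rule "update the generator once every `discriminator_steps` iterations" is a hypothetical stand-in for should_gen_update(), whose exact behavior is not documented on this page, and `gen_update_due`, `train_step_sketch`, `update_discriminator` and `update_generator` are hypothetical names.

```python
from typing import Callable, Dict, List


def gen_update_due(curr_iter: int, discriminator_steps: int = 1) -> bool:
    # Hypothetical stand-in for should_gen_update(): update the generator
    # once every `discriminator_steps` discriminator updates.
    return (curr_iter + 1) % discriminator_steps == 0


def train_step_sketch(
    curr_iter: int,
    update_discriminator: Callable[[], Dict[str, float]],
    update_generator: Callable[[], Dict[str, float]],
    discriminator_steps: int = 1,
) -> Dict[str, float]:
    log_vars: Dict[str, float] = {}
    # The discriminator is updated on every call, since it consumes the real data.
    log_vars.update(update_discriminator())
    # The generator is updated only when the iteration-based rule allows it.
    if gen_update_due(curr_iter, discriminator_steps):
        log_vars.update(update_generator())
    return log_vars


# Usage: record the order of updates over four iterations.
calls: List[str] = []


def fake_disc() -> Dict[str, float]:
    calls.append("D")
    return {"loss_disc": 0.0}


def fake_gen() -> Dict[str, float]:
    calls.append("G")
    return {"loss_gen": 0.0}


for it in range(4):
    train_step_sketch(it, fake_disc, fake_gen, discriminator_steps=2)
print(calls)  # ['D', 'D', 'G', 'D', 'D', 'G']
```

Merging the per-network log dicts into one mirrors the "dict of tensors for logging" return value, with plain floats standing in for tensors.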

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]

Train generator.

Parameters
  • inputs (TrainInput) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Not used in the generator’s training.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]
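As an illustration of the kind of value train_generator computes and logs: the StyleGAN2 paper trains its generator with the non-saturating logistic loss softplus(-D(G(z))). This page does not state which loss this class is configured with, so treat the following as a sketch under that assumption. It works on plain Python floats instead of tensors, only needs discriminator outputs for generated images (consistent with data_samples being unused here), and the `loss_gen` log key is hypothetical.

```python
import math
from typing import Dict, Sequence


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def generator_loss(fake_logits: Sequence[float]) -> Dict[str, float]:
    """Non-saturating logistic generator loss, averaged over the batch.

    fake_logits are discriminator outputs D(G(z)) for generated images.
    The returned key name is hypothetical.
    """
    loss = sum(softplus(-logit) for logit in fake_logits) / len(fake_logits)
    return {"loss_gen": loss}


# The loss shrinks as the discriminator scores generated images as more "real".
print(generator_loss([0.0, 2.0]))
```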

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]

Train discriminator.

Parameters
  • inputs (TrainInput) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]
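The discriminator side can be sketched the same way. Assuming, as in the StyleGAN2 paper, the logistic loss softplus(-D(x)) + softplus(D(G(z))), a minimal float-based sketch looks like this; this page does not confirm which loss the class uses, the three log keys are hypothetical, and the R1 regularization that StyleGAN2 also applies to real images is omitted.

```python
import math
from typing import Dict, Sequence


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def discriminator_loss(
    real_logits: Sequence[float], fake_logits: Sequence[float]
) -> Dict[str, float]:
    """Logistic discriminator loss, averaged over each batch.

    real_logits are D(x) for real images; fake_logits are D(G(z)).
    The returned key names are hypothetical.
    """
    # Penalize low scores on real images.
    loss_real = sum(softplus(-logit) for logit in real_logits) / len(real_logits)
    # Penalize high scores on generated images.
    loss_fake = sum(softplus(logit) for logit in fake_logits) / len(fake_logits)
    return {
        "loss_disc_real": loss_real,
        "loss_disc_fake": loss_fake,
        "loss_disc": loss_real + loss_fake,
    }


print(discriminator_loss([3.0, 1.5], [-2.0, 0.5]))
```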
