mmedit.models.editors.mspie.mspie_stylegan2

Module Contents

Classes

MSPIEStyleGAN2

MS-PIE StyleGAN2.

Attributes

ModelType

TrainInput

mmedit.models.editors.mspie.mspie_stylegan2.ModelType[source]
mmedit.models.editors.mspie.mspie_stylegan2.TrainInput[source]
class mmedit.models.editors.mspie.mspie_stylegan2.MSPIEStyleGAN2(*args, train_settings=dict(), **kwargs)[source]

Bases: mmedit.models.editors.stylegan2.StyleGAN2

MS-PIE StyleGAN2.

In this GAN, we adopt the MS-PIE training schedule so that a single generator can produce images at multiple scales. For details, see: Positional Encoding as Spatial Inductive Bias in GANs, CVPR 2021.

Parameters

train_settings (dict) – Config for training settings. Defaults to dict().
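
As a rough illustration of what a multi-scale training schedule involves, the sketch below samples a scale offset per iteration and derives the resulting output resolution. This is a minimal sketch under stated assumptions, not the class's actual implementation: the key names in `example_train_settings`, the helper functions, and the rule that each upsampling block doubles a `(base_size + scale)` input are all illustrative and are not confirmed by this page.

```python
import random

# Hypothetical settings; the key names are assumptions, not documented here.
example_train_settings = dict(
    num_upblocks=6,                       # assumed number of 2x upsampling blocks
    multi_input_scales=[0, 2, 4],         # assumed offsets added to the base input size
    multi_scale_probability=[0.5, 0.25, 0.25],  # assumed sampling weight per offset
)


def sample_scale(scales, probabilities, rng):
    """Pick one scale offset according to the given sampling weights."""
    return rng.choices(scales, weights=probabilities, k=1)[0]


def output_size(base_size, scale, num_upblocks):
    """Assumed rule: the input grid is (base_size + scale) and each block doubles it."""
    return (base_size + scale) * 2 ** num_upblocks


if __name__ == "__main__":
    rng = random.Random(0)
    settings = example_train_settings
    for _ in range(3):
        scale = sample_scale(
            settings["multi_input_scales"],
            settings["multi_scale_probability"],
            rng,
        )
        print(scale, output_size(4, scale, settings["num_upblocks"]))
```

Under these assumptions, one generator yields a different output resolution whenever a different offset is drawn, which is the behavior the class description refers to.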

train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict) → Dict[str, torch.Tensor][source]

Train the GAN model. When training GANs, the generator and discriminator are updated alternately. In MMGeneration’s design, self.train_step is called with data input. The discriminator, whose update relies on real data, is therefore always updated first; the current iteration count then determines whether the generator is updated as well. See should_gen_update() for how that decision is made.

Parameters
  • data (dict) – Data sampled from the dataloader.

  • optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance containing the OptimWrapper of the generator and of the discriminator.

Returns

A dict of tensors for logging.

Return type

Dict[str, torch.Tensor]
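
The update order described for train_step can be sketched in plain Python. This is a simplified stand-in, not the library's code: `should_gen_update_sketch`, `train_step_sketch`, the `disc_steps` parameter, and the modulo rule are assumptions made for illustration, since the real decision logic lives in should_gen_update() and is not shown on this page.

```python
from typing import Callable, Dict, List


def should_gen_update_sketch(curr_iter: int, disc_steps: int = 1) -> bool:
    """Assumed rule: update the generator once every `disc_steps` iterations."""
    return (curr_iter + 1) % disc_steps == 0


def train_step_sketch(
    curr_iter: int,
    train_disc: Callable[[], Dict[str, float]],
    train_gen: Callable[[], Dict[str, float]],
    disc_steps: int = 1,
) -> Dict[str, float]:
    """Always update the discriminator, then conditionally update the generator."""
    log_vars = dict(train_disc())  # the discriminator step runs on every call
    if should_gen_update_sketch(curr_iter, disc_steps):
        log_vars.update(train_gen())  # the generator step runs only when scheduled
    return log_vars


def run_schedule(num_iters: int, disc_steps: int) -> List[List[str]]:
    """Record which losses would be logged at each iteration."""
    history = []
    for it in range(num_iters):
        logs = train_step_sketch(
            it,
            train_disc=lambda: {"loss_disc": 0.0},  # placeholder loss values
            train_gen=lambda: {"loss_gen": 0.0},
            disc_steps=disc_steps,
        )
        history.append(sorted(logs))
    return history


if __name__ == "__main__":
    for it, keys in enumerate(run_schedule(4, disc_steps=2)):
        print(it, keys)
```

With `disc_steps=2` in this sketch, every iteration logs a discriminator loss and every second iteration also logs a generator loss.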

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor][source]

Train the generator.

Parameters
  • inputs (TrainInput) – Inputs from the dataloader.

  • data_samples (List[EditDataSample]) – Data samples from the dataloader. Not used in generator training.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor][source]

Train the discriminator.

Parameters
  • inputs (TrainInput) – Inputs from the dataloader.

  • data_samples (List[EditDataSample]) – Data samples from the dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]
