
mmedit.models.editors.lsgan.lsgan

Module Contents

Classes

LSGAN

Implementation of Least Squares Generative Adversarial Networks.

class mmedit.models.editors.lsgan.lsgan.LSGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None)[source]

Bases: mmedit.models.base_models.BaseGAN

Implementation of Least Squares Generative Adversarial Networks.

Paper link: https://arxiv.org/pdf/1611.04076.pdf

The detailed architecture can be found in LSGANGenerator and LSGANDiscriminator.

disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor) → Tuple[source]

Get the discriminator loss. LSGAN uses the least squares loss to train the discriminator:

\[L_{D}=\left(D\left(X_{\text{data}}\right)-1\right)^{2}+\left(D(G(z))\right)^{2}\]

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]
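As a rough, framework-free illustration of the discriminator formula, the sketch below evaluates it on plain Python lists of predictions. It assumes the per-sample squared errors are averaged over the batch (mean reduction); the helper names and log keys are hypothetical, and this is not the mmedit implementation, which works on `torch.Tensor`s.

```python
from typing import Dict, Sequence, Tuple


def mean_squared_error(preds: Sequence[float], target: float) -> float:
    """Mean of (pred - target)^2 over the batch (mean reduction assumed)."""
    return sum((p - target) ** 2 for p in preds) / len(preds)


def lsgan_disc_loss(disc_pred_fake: Sequence[float],
                    disc_pred_real: Sequence[float]) -> Tuple[float, Dict[str, float]]:
    """Least-squares discriminator loss: real predictions target 1, fake target 0."""
    loss_real = mean_squared_error(disc_pred_real, 1.0)  # (D(x_data) - 1)^2 term
    loss_fake = mean_squared_error(disc_pred_fake, 0.0)  # (D(G(z)))^2 term
    loss = loss_real + loss_fake
    # Hypothetical log keys, chosen for illustration only.
    log_vars = {"loss_disc_real": loss_real, "loss_disc_fake": loss_fake, "loss": loss}
    return loss, log_vars


loss, log_vars = lsgan_disc_loss([0.0, 0.5], [1.0, 0.5])
print(loss)  # 0.25
```

A perfect discriminator (real predictions at 1, fake at 0) reaches zero loss; each prediction that is 0.5 away from its target contributes 0.25 before averaging.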

gen_loss(disc_pred_fake: torch.Tensor) → Tuple[source]

Get the generator loss. LSGAN uses the least squares loss to train the generator:

\[L_{G}=\left(D(G(z))-1\right)^{2}\]

Parameters

disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]
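The generator objective can be sketched the same way. This is a minimal, hypothetical illustration on Python lists, assuming mean reduction over the batch; the log key `loss_gen` is an assumption, not necessarily what mmedit logs.

```python
from typing import Dict, Sequence, Tuple


def lsgan_gen_loss(disc_pred_fake: Sequence[float]) -> Tuple[float, Dict[str, float]]:
    """Least-squares generator loss: push D(G(z)) towards the real label 1."""
    loss = sum((p - 1.0) ** 2 for p in disc_pred_fake) / len(disc_pred_fake)
    return loss, {"loss_gen": loss}  # hypothetical log key


print(lsgan_gen_loss([1.0, 0.0])[0])  # 0.5
```

Unlike the discriminator loss, only the fake predictions appear here: the generator is rewarded when the discriminator scores its samples close to the real label.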

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor][source]

Train the discriminator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update the model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]
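One discriminator update can be pictured with a toy example. The sketch below is hypothetical and is not mmedit's code: it replaces the discriminator with a one-dimensional linear model D(x) = w*x + b, takes fake samples that a frozen generator has already produced, evaluates the least-squares discriminator loss, and applies one manual gradient-descent step in place of the optimizer wrapper's update. The function name, parameter dict, learning rate, and log keys are all illustrative assumptions.

```python
from typing import Dict, List


def train_discriminator_step(params: Dict[str, float],
                             real: List[float],
                             fake: List[float],
                             lr: float = 0.1) -> Dict[str, float]:
    """One gradient step on L_D for a toy linear discriminator D(x) = w*x + b.

    `fake` holds samples G(z) from a frozen generator. Updates `params` in
    place and returns the losses measured before the step.
    """
    w, b = params["w"], params["b"]
    pred_real = [w * x + b for x in real]
    pred_fake = [w * x + b for x in fake]
    n_r, n_f = len(real), len(fake)
    loss_real = sum((p - 1.0) ** 2 for p in pred_real) / n_r
    loss_fake = sum(p ** 2 for p in pred_fake) / n_f
    # Analytic gradients of the batch-averaged least-squares loss w.r.t. w and b.
    grad_w = (sum(2 * (p - 1.0) * x for p, x in zip(pred_real, real)) / n_r
              + sum(2 * p * x for p, x in zip(pred_fake, fake)) / n_f)
    grad_b = (sum(2 * (p - 1.0) for p in pred_real) / n_r
              + sum(2 * p for p in pred_fake) / n_f)
    # Plain gradient descent stands in for the optimizer wrapper's update.
    params["w"] = w - lr * grad_w
    params["b"] = b - lr * grad_b
    return {"loss_disc_real": loss_real, "loss_disc_fake": loss_fake,
            "loss": loss_real + loss_fake}


params = {"w": 0.0, "b": 0.5}
for _ in range(3):
    print(train_discriminator_step(params, [1.0, 2.0], [-1.0, -2.0])["loss"])
```

With a small learning rate the loss shrinks from step to step as the toy discriminator learns to score real samples near 1 and fake samples near 0.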

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor][source]

Train the generator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from the dataloader. Not used in the generator’s training.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update the model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]
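The generator update can be sketched with the same kind of toy. This is a hypothetical illustration, not mmedit's implementation: a linear generator G(z) = a*z + c is trained against a frozen linear discriminator D(x) = w*x + b using only noise, which matches the note that data samples are not used here. All names, the learning rate, and the log key are assumptions.

```python
from typing import Dict, List


def train_generator_step(gen: Dict[str, float],
                         disc: Dict[str, float],
                         noise: List[float],
                         lr: float = 0.1) -> Dict[str, float]:
    """One gradient step on L_G for a toy linear generator G(z) = a*z + c.

    The toy discriminator D(x) = w*x + b is frozen: only `gen` is updated.
    Returns the loss measured before the step.
    """
    a, c = gen["a"], gen["c"]
    w, b = disc["w"], disc["b"]
    fake = [a * z + c for z in noise]
    pred_fake = [w * x + b for x in fake]
    n = len(noise)
    loss = sum((p - 1.0) ** 2 for p in pred_fake) / n
    # Chain rule: dL/da = mean(2 (D(G(z)) - 1) * w * z), dL/dc = mean(2 (D(G(z)) - 1) * w)
    grad_a = sum(2 * (p - 1.0) * w * z for p, z in zip(pred_fake, noise)) / n
    grad_c = sum(2 * (p - 1.0) * w for p in pred_fake) / n
    gen["a"] = a - lr * grad_a
    gen["c"] = c - lr * grad_c
    return {"loss_gen": loss}


gen, disc = {"a": 0.0, "c": 0.0}, {"w": 1.0, "b": 0.0}
for _ in range(3):
    print(train_generator_step(gen, disc, [1.0, -1.0])["loss_gen"])
```

Keeping the discriminator parameters fixed during this step mirrors the usual alternating GAN schedule, in which each network is updated while the other is held constant.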
