
mmedit.models.losses.loss_comps.gan_loss_comps

Module Contents

Classes

GANLossComps – Define GAN loss.

class mmedit.models.losses.loss_comps.gan_loss_comps.GANLossComps(gan_type: str, real_label_val: float = 1.0, fake_label_val: float = 0.0, loss_weight: float = 1.0) [source]

Bases: torch.nn.Module

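GAN loss module. It computes the adversarial loss for either a generator or a discriminator, using the formulation selected by gan_type.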

Parameters
  • gan_type (str) – Type of GAN loss. Supported values: ‘vanilla’, ‘lsgan’, ‘wgan’, ‘hinge’, ‘wgan-logistic-ns’.

  • real_label_val (float) – The value for real label. Default: 1.0.

  • fake_label_val (float) – The value for fake label. Default: 0.0.

  • loss_weight (float) – Loss weight. Default: 1.0. Note that loss_weight only applies to generators; for discriminators it is always 1.0.
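For the two label-based types, ‘vanilla’ and ‘lsgan’, real_label_val and fake_label_val set the constant target that every prediction is compared against. The plain-Python sketch below illustrates this on lists of floats. It assumes that ‘vanilla’ means binary cross-entropy on raw logits and ‘lsgan’ means mean squared error. That is the usual convention, but this page does not state it. All function names here are hypothetical and are not part of mmedit.

```python
import math
from typing import List


def bce_with_logits(logits: List[float], label: float) -> float:
    """Mean binary cross-entropy on raw logits (assumed 'vanilla' loss)."""
    total = 0.0
    for x in logits:
        # Stable form of -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))]
        total += max(x, 0.0) - x * label + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def mse(preds: List[float], label: float) -> float:
    """Mean squared error against a constant label (assumed 'lsgan' loss)."""
    return sum((p - label) ** 2 for p in preds) / len(preds)


def label_based_loss(gan_type: str, preds: List[float], target_is_real: bool,
                     real_label_val: float = 1.0,
                     fake_label_val: float = 0.0) -> float:
    """Hypothetical helper: pick the target value, then apply the loss."""
    label = real_label_val if target_is_real else fake_label_val
    if gan_type == 'vanilla':
        return bce_with_logits(preds, label)
    if gan_type == 'lsgan':
        return mse(preds, label)
    # 'wgan', 'hinge' and 'wgan-logistic-ns' are not label-based in this sketch.
    raise ValueError(f'{gan_type!r} is not a label-based GAN type')


# Usage: LSGAN discriminator loss on real samples, with a soft real label of 0.9.
print(label_based_loss('lsgan', [0.8, 1.0], True, real_label_val=0.9))
```

Setting real_label_val below 1.0 in this way is one common use of the parameter (label smoothing).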

_wgan_loss(input: torch.Tensor, target: bool) → torch.Tensor [source]

Compute the WGAN loss.

Parameters
  • input (Tensor) – Input tensor.

  • target (bool) – Target label: True for real, False for fake.

Returns

The WGAN loss.

Return type

Tensor
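The WGAN loss reduces to the mean critic score, with a sign that depends on the target. The minimal plain-Python sketch below assumes the standard WGAN convention: a real target gives the negated mean and a fake target gives the mean. wgan_loss_sketch is a hypothetical name, and it works on a list of floats rather than a tensor.

```python
from typing import List


def wgan_loss_sketch(preds: List[float], target: bool) -> float:
    """Sign-flipped mean of critic scores (assumed WGAN convention)."""
    mean = sum(preds) / len(preds)
    # target=True (real): minimising -E[D(x)] pushes real scores up.
    # target=False (fake): minimising E[D(x)] pushes fake scores down.
    return -mean if target else mean


# Discriminator total: real term + fake term = E[D(fake)] - E[D(real)].
d_loss = wgan_loss_sketch([2.0, 4.0], True) + wgan_loss_sketch([1.0, 1.0], False)
print(d_loss)  # -2.0
```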

_wgan_logistic_ns_loss(input: torch.Tensor, target: bool) → torch.Tensor [source]

WGAN loss in the logistic non-saturating mode.

This loss is widely used in StyleGANv2.

Parameters
  • input (Tensor) – Input tensor.

  • target (bool) – Target label: True for real, False for fake.

Returns

The WGAN loss (logistic non-saturating variant).

Return type

Tensor
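The logistic non-saturating form replaces the raw mean with a softplus. As a result, the generator still receives a useful gradient even when the critic scores its samples very low. The sketch below assumes the StyleGAN2-style formulation: a real target gives mean(softplus(-x)) and a fake target gives mean(softplus(x)). The helper names are hypothetical, and a list of floats stands in for the tensor.

```python
import math
from typing import List


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def wgan_logistic_ns_sketch(preds: List[float], target: bool) -> float:
    # Real target: softplus(-x) == -log(sigmoid(x)), the non-saturating term.
    # Fake target: softplus(x)  == -log(1 - sigmoid(x)).
    vals = [softplus(-p) if target else softplus(p) for p in preds]
    return sum(vals) / len(vals)


# Generator loss on fakes that the critic scores very low: large rather than flat.
print(wgan_logistic_ns_sketch([-10.0], True))
```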

get_target_label(input: torch.Tensor, target_is_real: bool) → Union[bool, torch.Tensor] [source]

Get target label.

Parameters
  • input (Tensor) – Input tensor.

  • target_is_real (bool) – Whether the target is real or fake.

Returns

The target label: a bool for ‘wgan’, otherwise a Tensor.

Return type

(bool | Tensor)
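The return type follows from how each loss consumes its target. The WGAN-style losses only need to know whether the target is real. The label-based losses need label values with the same shape as the input. The plain-Python sketch below uses a list in place of the tensor. It assumes that ‘wgan-logistic-ns’ is handled like ‘wgan’, because _wgan_logistic_ns_loss also takes a bool target. The function name is hypothetical.

```python
from typing import List, Union


def get_target_label_sketch(preds: List[float], target_is_real: bool,
                            gan_type: str, real_label_val: float = 1.0,
                            fake_label_val: float = 0.0
                            ) -> Union[bool, List[float]]:
    # Assumption: both WGAN variants consume a bool target directly.
    if gan_type in ('wgan', 'wgan-logistic-ns'):
        return target_is_real
    # Other types: a constant label, one entry per prediction.
    value = real_label_val if target_is_real else fake_label_val
    return [value] * len(preds)


print(get_target_label_sketch([0.3, -0.2], False, 'lsgan', fake_label_val=0.1))
# [0.1, 0.1]
```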

forward(input: torch.Tensor, target_is_real: bool, is_disc: bool = False) → torch.Tensor [source]

Compute the GAN loss.

Parameters
  • input (Tensor) – The input for the loss module, i.e., the network prediction.

  • target_is_real (bool) – Whether the target is real or fake.

  • is_disc (bool) – Whether the loss is computed for a discriminator. Default: False.

Returns

GAN loss value.

Return type

Tensor
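Two details of forward are easy to miss. First, is_disc chooses between the discriminator and generator objectives, which matters most for ‘hinge’. Second, loss_weight scales only the generator loss, as the class parameters note. The sketch below shows both for the hinge case. It assumes the common hinge-GAN objectives:

  • discriminator: mean(relu(1 - D)) on real and mean(relu(1 + D)) on fake.

  • generator: -mean(D).

hinge_forward_sketch is a hypothetical name, and a list of floats stands in for the tensor.

```python
from typing import List


def hinge_forward_sketch(preds: List[float], target_is_real: bool,
                         is_disc: bool = False,
                         loss_weight: float = 1.0) -> float:
    n = len(preds)
    if is_disc:
        # Flip the sign for real targets so both cases use relu(1 + s).
        signed = [-p if target_is_real else p for p in preds]
        loss = sum(max(0.0, 1.0 + s) for s in signed) / n
    else:
        # Generator: raise the critic score on fakes; target_is_real is unused.
        loss = -sum(preds) / n
    # loss_weight applies to generators only; discriminators effectively use 1.0.
    return loss if is_disc else loss * loss_weight


print(hinge_forward_sketch([2.0, 0.0], True, is_disc=True, loss_weight=0.5))   # 0.5
print(hinge_forward_sketch([2.0, 0.0], True, is_disc=False, loss_weight=0.5))  # -0.5
```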
