mmedit.models.losses.gan_loss

Module Contents

Classes

GANLoss

Define GAN loss.

GaussianBlur

A Gaussian filter which blurs a given tensor with a two-dimensional Gaussian kernel.

GradientPenaltyLoss

Gradient penalty loss for wgan-gp.

DiscShiftLoss

Disc shift loss.

Functions

gradient_penalty_loss(→ torch.Tensor)

Calculate gradient penalty for wgan-gp.

disc_shift_loss(→ torch.Tensor)

Disc Shift loss.

r1_gradient_penalty_loss(→ torch.Tensor)

Calculate R1 gradient penalty for WGAN-GP.

gen_path_regularizer(→ Tuple[torch.Tensor])

Generator Path Regularization.

class mmedit.models.losses.gan_loss.GANLoss(gan_type: str, real_label_val: float = 1.0, fake_label_val: float = 0.0, loss_weight: float = 1.0)[source]

Bases: torch.nn.Module

Define GAN loss.

Parameters
  • gan_type (str) – The GAN type. Supported values: ‘vanilla’, ‘lsgan’, ‘wgan’, ‘hinge’.

  • real_label_val (float) – The value for real label. Default: 1.0.

  • fake_label_val (float) – The value for fake label. Default: 0.0.

  • loss_weight (float) – Loss weight. Default: 1.0. Note that loss_weight applies only to generators; for discriminators it is always 1.0.

_wgan_loss(input: torch.Tensor, target: bool) → torch.Tensor[source]

WGAN loss.

Parameters
  • input (Tensor) – Input tensor.

  • target (bool) – Target label.

Returns

WGAN loss.

Return type

Tensor

get_target_label(input: torch.Tensor, target_is_real: bool) → Union[bool, torch.Tensor][source]

Get target label.

Parameters
  • input (Tensor) – Input tensor.

  • target_is_real (bool) – Whether the target is real or fake.

Returns

Target label. A bool is returned for wgan; otherwise a Tensor is returned.

Return type

(bool | Tensor)

forward(input: torch.Tensor, target_is_real: bool, is_disc: bool = False, mask: Optional[torch.Tensor] = None) → torch.Tensor[source]

Parameters
  • input (Tensor) – The input for the loss module, i.e., the network prediction.

  • target_is_real (bool) – Whether the target is real or fake.

  • is_disc (bool) – Whether the loss is computed for a discriminator. Default: False.

  • mask (Tensor) – The mask tensor. Default: None.

Returns

GAN loss value.

Return type

Tensor
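As an illustration of how the four gan_type options differ, the following is a minimal sketch written against plain PyTorch. It is not the library source: the function name gan_loss_sketch is hypothetical, the mask argument is left out, and the mapping of ‘vanilla’ to binary cross-entropy on logits and ‘lsgan’ to mean squared error follows the usual formulations and is an assumption here.

```python
import torch
import torch.nn.functional as F


def gan_loss_sketch(pred, target_is_real, gan_type, is_disc=False,
                    real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
    """Illustrative stand-in for a GAN loss forward pass (hypothetical helper)."""
    if gan_type == 'wgan':
        # The Wasserstein objective needs no label tensor, only a sign.
        loss = -pred.mean() if target_is_real else pred.mean()
    elif gan_type == 'hinge':
        if is_disc:
            # Discriminator: penalize real scores below 1 and fake scores above -1.
            signed = -pred if target_is_real else pred
            loss = F.relu(1 + signed).mean()
        else:
            # Generator: raise the discriminator score of generated samples.
            loss = -pred.mean()
    else:
        # Label-based losses compare the prediction with a constant target tensor.
        label = real_label_val if target_is_real else fake_label_val
        target = torch.full_like(pred, label)
        if gan_type == 'vanilla':
            loss = F.binary_cross_entropy_with_logits(pred, target)
        elif gan_type == 'lsgan':
            loss = F.mse_loss(pred, target)
        else:
            raise ValueError(f'Unsupported gan_type: {gan_type}')
    # loss_weight scales generator losses only; discriminators always use 1.0.
    return loss if is_disc else loss * loss_weight
```

In this sketch, passing is_disc=True returns the unweighted loss, which mirrors the note on loss_weight above.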

class mmedit.models.losses.gan_loss.GaussianBlur(kernel_size: Tuple[int, int] = (71, 71), sigma: Tuple[float, float] = (10.0, 10.0))[source]

Bases: torch.nn.Module

A Gaussian filter which blurs a given tensor with a two-dimensional Gaussian kernel by convolving it along each channel. Batch operation is supported.

This class is modified from kornia.filters.gaussian: <https://kornia.readthedocs.io/en/latest/_modules/kornia/filters/gaussian.html>.

Parameters
  • kernel_size (tuple[int]) – The size of the kernel. Default: (71, 71).

  • sigma (tuple[float]) – The standard deviation of the kernel. Default: (10.0, 10.0).

Returns

The Gaussian-blurred tensor.

Return type

Tensor

Shape:
  • input: Tensor with shape of (n, c, h, w)

  • output: Tensor with shape of (n, c, h, w)

static compute_zero_padding(kernel_size: Tuple[int, int]) → tuple[source]

Compute zero padding tuple.

Parameters

kernel_size (tuple[int]) – The size of the kernel.

Returns

Padding of height and width.

Return type

tuple
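The padding that keeps the output the same size as the input for an odd kernel can be sketched as follows. This is a plain-Python illustration under the assumption that the padding is (k - 1) // 2 per dimension; compute_zero_padding_sketch is a hypothetical name, not the library method.

```python
def compute_zero_padding_sketch(kernel_size):
    """Return (pad_h, pad_w) that preserves spatial size for odd kernels."""
    pad_h, pad_w = [(k - 1) // 2 for k in kernel_size]
    return pad_h, pad_w


# With stride 1, output size = input + 2 * pad - kernel + 1.
h, w = 64, 48
pad_h, pad_w = compute_zero_padding_sketch((71, 71))
out_h = h + 2 * pad_h - 71 + 1
out_w = w + 2 * pad_w - 71 + 1
```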

get_2d_gaussian_kernel(kernel_size: Tuple[int, int], sigma: Tuple[float, float]) → torch.Tensor[source]

Get the two-dimensional Gaussian filter matrix coefficients.

Parameters
  • kernel_size (tuple[int]) – Kernel filter size in the x and y direction. The kernel sizes should be odd and positive.

  • sigma (tuple[float]) – Gaussian standard deviation in the x and y direction.

Returns

kernel_2d – A 2D torch tensor with Gaussian filter matrix coefficients.

Return type

Tensor

get_1d_gaussian_kernel(kernel_size: int, sigma: float) → torch.Tensor[source]

Get the Gaussian filter coefficients in one dimension (x or y direction).

Parameters
  • kernel_size (int) – Kernel filter size in x or y direction. Should be odd and positive.

  • sigma (float) – Gaussian standard deviation in x or y direction.

Returns

kernel_1d – A 1D torch tensor with Gaussian filter coefficients in x or y direction.

Return type

Tensor

gaussian(kernel_size: int, sigma: float) → torch.Tensor[source]

Gaussian function.

Parameters
  • kernel_size (int) – Kernel filter size in x or y direction. Should be odd and positive.

  • sigma (float) – Gaussian standard deviation in x or y direction.

Returns

A 1D torch tensor with Gaussian filter coefficients in x or y direction.

Return type

Tensor
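The relationship between the 1D coefficients and the 2D kernel can be sketched in plain Python. The example uses lists rather than tensors for readability, assumes the 2D kernel is the outer product of two normalized 1D kernels, and uses hypothetical helper names (gaussian_1d_sketch, gaussian_2d_sketch) that are not part of the library.

```python
import math


def gaussian_1d_sketch(kernel_size, sigma):
    """Normalized 1D Gaussian weights centered on the middle tap."""
    center = kernel_size // 2
    weights = [math.exp(-((i - center) ** 2) / (2.0 * sigma ** 2))
               for i in range(kernel_size)]
    total = sum(weights)
    return [w / total for w in weights]


def gaussian_2d_sketch(kernel_size, sigma):
    """Separable 2D kernel built as the outer product of two 1D kernels."""
    kernel_x = gaussian_1d_sketch(kernel_size[0], sigma[0])
    kernel_y = gaussian_1d_sketch(kernel_size[1], sigma[1])
    return [[kx * ky for ky in kernel_y] for kx in kernel_x]


kernel_1d = gaussian_1d_sketch(5, 1.0)
kernel_2d = gaussian_2d_sketch((5, 3), (1.0, 0.5))
```

Because each 1D kernel sums to one, the 2D kernel also sums to one, so blurring with it preserves the mean intensity of a constant region.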

forward(x: torch.Tensor) → torch.Tensor[source]

Forward function.

Parameters

x (Tensor) – Tensor with shape (n, c, h, w).

Returns

The Gaussian-blurred tensor.

Return type

Tensor
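Convolving "along each channel" is typically done with a grouped (depthwise) convolution. The sketch below shows one way to do that in PyTorch; gaussian_blur_sketch is a hypothetical function, the small default kernel is chosen only to keep the example fast, and the details may differ from the library implementation.

```python
import torch
import torch.nn.functional as F


def gaussian_blur_sketch(x, kernel_size=(5, 5), sigma=(1.0, 1.0)):
    """Blur an (n, c, h, w) tensor channel by channel with a 2D Gaussian."""
    def kernel_1d(size, std):
        coords = torch.arange(size, dtype=x.dtype, device=x.device) - size // 2
        weights = torch.exp(-coords ** 2 / (2.0 * std ** 2))
        return weights / weights.sum()

    kernel_2d = torch.outer(kernel_1d(kernel_size[0], sigma[0]),
                            kernel_1d(kernel_size[1], sigma[1]))
    channels = x.shape[1]
    # One copy of the kernel per channel; groups=channels keeps channels independent.
    weight = kernel_2d.expand(channels, 1, *kernel_size).contiguous()
    # Zero padding of (k - 1) // 2 keeps the output at (n, c, h, w).
    padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
    return F.conv2d(x, weight, padding=padding, groups=channels)


blurred = gaussian_blur_sketch(torch.ones(2, 3, 8, 8))
```

With zero padding, values near the border are attenuated because part of the kernel falls outside the image.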

mmedit.models.losses.gan_loss.gradient_penalty_loss(discriminator: torch.nn.Module, real_data: torch.Tensor, fake_data: torch.Tensor, mask: Optional[torch.Tensor] = None, norm_mode: str = 'pixel') → torch.Tensor[source]

Calculate gradient penalty for wgan-gp.

Parameters
  • discriminator (nn.Module) – Network for the discriminator.

  • real_data (Tensor) – Real input data.

  • fake_data (Tensor) – Fake input data.

  • mask (Tensor) – Masks for inpainting. Default: None.

Returns

A tensor for gradient penalty.

Return type

Tensor
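For orientation, the WGAN-GP penalty evaluates the discriminator at random interpolations between real and fake samples and pushes the gradient norm there toward 1. The sketch below is a simplified illustration, not the library function: gradient_penalty_sketch is a hypothetical name, the mask and norm_mode arguments are omitted, and the norm is taken over all of (c, h, w) per sample as an assumption.

```python
import torch
from torch import autograd


def gradient_penalty_sketch(discriminator, real_data, fake_data):
    """Push the gradient norm at interpolated samples toward 1."""
    batch_size = real_data.size(0)
    # Random interpolation point between each real/fake pair.
    alpha = torch.rand(batch_size, 1, 1, 1,
                       device=real_data.device, dtype=real_data.dtype)
    interpolates = (alpha * real_data + (1.0 - alpha) * fake_data).detach()
    interpolates.requires_grad_(True)

    scores = discriminator(interpolates)
    gradients = autograd.grad(
        outputs=scores,
        inputs=interpolates,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # so the penalty itself can be backpropagated
    )[0]
    # Per-sample L2 norm over all of (c, h, w).
    grad_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()


# Toy linear "discriminators" whose gradients are known in closed form.
real = torch.randn(4, 1, 2, 2)
fake = torch.randn(4, 1, 2, 2)
gp_unit = gradient_penalty_sketch(lambda x: (0.5 * x).flatten(1).sum(1), real, fake)
gp_two = gradient_penalty_sketch(lambda x: x.flatten(1).sum(1), real, fake)
```

In the first toy case every gradient has norm 1, so the penalty vanishes; in the second the norm is 2, so the penalty is (2 - 1)^2 = 1.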

class mmedit.models.losses.gan_loss.GradientPenaltyLoss(loss_weight: float = 1.0)[source]

Bases: torch.nn.Module

Gradient penalty loss for wgan-gp.

Parameters

loss_weight (float) – Loss weight. Default: 1.0.

forward(discriminator: torch.nn.Module, real_data: torch.Tensor, fake_data: torch.Tensor, mask: Optional[torch.Tensor] = None) → torch.Tensor[source]

Forward function.

Parameters
  • discriminator (nn.Module) – Network for the discriminator.

  • real_data (Tensor) – Real input data.

  • fake_data (Tensor) – Fake input data.

  • mask (Tensor) – Masks for inpainting. Default: None.

Returns

Loss.

Return type

Tensor

mmedit.models.losses.gan_loss.disc_shift_loss(pred: torch.Tensor) → torch.Tensor[source]

Disc shift loss.

This loss is proposed in PGGAN as an auxiliary loss for the discriminator.

Parameters

pred (Tensor) – Input tensor.

Returns

Loss tensor.

Return type

Tensor

class mmedit.models.losses.gan_loss.DiscShiftLoss(loss_weight: float = 0.1)[source]

Bases: torch.nn.Module

Disc shift loss.

Parameters

loss_weight (float, optional) – Loss weight. Defaults to 0.1.

forward(x: torch.Tensor) → torch.Tensor[source]

Forward function.

Parameters

x (Tensor) – Tensor with shape (n, c, h, w).

Returns

Loss.

Return type

Tensor
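The auxiliary term from PGGAN keeps discriminator outputs from drifting far from zero by penalizing their squared magnitude. A minimal sketch, assuming the loss is the mean of the squared predictions scaled by loss_weight (disc_shift_sketch is a hypothetical name, not the library function):

```python
import torch


def disc_shift_sketch(pred, loss_weight=0.1):
    """Penalize the mean squared discriminator output."""
    return loss_weight * torch.mean(pred ** 2)


shift = disc_shift_sketch(torch.tensor([1.0, -3.0]))
```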

mmedit.models.losses.gan_loss.r1_gradient_penalty_loss(discriminator: torch.nn.Module, real_data: torch.Tensor, mask: Optional[torch.Tensor] = None, norm_mode: str = 'pixel', loss_scaler: Optional[torch.cuda.amp.grad_scaler.GradScaler] = None, use_apex_amp: bool = False) → torch.Tensor[source]

Calculate R1 gradient penalty for WGAN-GP.

The R1 regularizer comes from: “Which Training Methods for GANs do actually Converge?”, ICML 2018.

Unlike the original gradient penalty, this regularizer only penalizes the gradient w.r.t. real data.

Parameters
  • discriminator (nn.Module) – Network for the discriminator.

  • real_data (Tensor) – Real input data.

  • mask (Tensor) – Masks for inpainting. Default: None.

  • norm_mode (str) – Decides along which dimension the norm of the gradients is calculated. Currently supported: “pixel” and “HWC”. Defaults to “pixel”.

Returns

A tensor for gradient penalty.

Return type

Tensor
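The R1 regularizer can be sketched in the same style as the gradient penalty, except that the gradient is taken at the real samples and its squared norm is penalized directly. This is an illustration only: r1_penalty_sketch is a hypothetical name, mask and mixed-precision handling are omitted, and the reading of “pixel” (norm over the channel dimension at each pixel) versus “HWC” (norm over all non-batch dimensions) is an assumption.

```python
import torch
from torch import autograd


def r1_penalty_sketch(discriminator, real_data, norm_mode='pixel'):
    """Squared gradient norm of the discriminator at real samples."""
    batch_size = real_data.size(0)
    real_data = real_data.detach().requires_grad_(True)
    scores = discriminator(real_data)
    gradients = autograd.grad(
        outputs=scores,
        inputs=real_data,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty can be optimized
    )[0]
    if norm_mode == 'pixel':
        # Squared norm across channels at each pixel, averaged over n, h, w.
        return gradients.pow(2).sum(dim=1).mean()
    if norm_mode == 'HWC':
        # Squared norm across all of (c, h, w), averaged over the batch.
        return gradients.pow(2).reshape(batch_size, -1).sum(dim=1).mean()
    raise ValueError(f'Unsupported norm_mode: {norm_mode}')


# A toy linear discriminator has a gradient of ones everywhere.
real = torch.randn(2, 3, 4, 4)
toy_disc = lambda x: x.flatten(1).sum(1)
r1_pixel = r1_penalty_sketch(toy_disc, real, 'pixel')
r1_hwc = r1_penalty_sketch(toy_disc, real, 'HWC')
```

For the toy discriminator, the “pixel” value equals the number of channels (3) and the “HWC” value equals c * h * w (48).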

mmedit.models.losses.gan_loss.gen_path_regularizer(generator: torch.nn.Module, num_batches: int, mean_path_length: torch.Tensor, pl_batch_shrink: int = 1, decay: float = 0.01, weight: float = 1.0, pl_batch_size: Optional[int] = None, sync_mean_buffer: bool = False, loss_scaler: Optional[torch.cuda.amp.grad_scaler.GradScaler] = None, use_apex_amp: bool = False) → Tuple[torch.Tensor][source]

Generator Path Regularization.

Path regularization is proposed in StyleGAN2 and can help improve the continuity of the latent space. More details can be found in: Analyzing and Improving the Image Quality of StyleGAN, CVPR 2020.

Parameters
  • generator (nn.Module) – The generator module. Note that this loss requires the generator to provide a return_latents interface, through which the latent code of the current sample can be obtained.

  • num_batches (int) – The number of samples used in calculating this loss.

  • mean_path_length (Tensor) – The mean path length, calculated by moving average.

  • pl_batch_shrink (int, optional) – The factor by which the batch size is shrunk to save GPU memory. Defaults to 1.

  • decay (float, optional) – Decay for moving average of mean path length. Defaults to 0.01.

  • weight (float, optional) – Weight of this loss item. Defaults to 1.0.

  • pl_batch_size (int | None, optional) – The batch size used in calculating the generator path. Once this argument is set, num_batches is overridden by it and is not affected by pl_batch_shrink. Defaults to None.

  • sync_mean_buffer (bool, optional) – Whether to sync mean path length across all GPUs. Defaults to False.

Returns

The penalty loss, detached mean path tensor, and current path length.

Return type

tuple[Tensor]
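The core computation behind path length regularization can be sketched as follows. This is a simplified illustration based on the StyleGAN2 formulation, not the library function: path_regularizer_sketch is a hypothetical helper that takes already generated images and their latents instead of a generator, and batch shrinking, buffer syncing, and mixed precision are omitted. The toy linear "generator" is likewise an assumption made only so the example runs.

```python
import math

import torch
from torch import autograd


def path_regularizer_sketch(fake_img, latents, mean_path_length,
                            decay=0.01, weight=1.0):
    """fake_img: (n, c, h, w), generated from latents: (n, num_layers, dim)."""
    # Random image-space direction, scaled so its magnitude does not
    # depend on the image resolution.
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3])
    grad = autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
    # Path length per sample: L2 over the latent dim, mean over layers.
    path_lengths = torch.sqrt(grad.pow(2).sum(dim=2).mean(dim=1))
    # Moving average of the mean path length.
    path_mean = mean_path_length + decay * (
        path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean() * weight
    return path_penalty, path_mean.detach(), path_lengths


# Hypothetical toy "generator": a fixed linear map from latents to a 1x4x4 image.
torch.manual_seed(0)
latents = torch.randn(2, 3, 8, requires_grad=True)
proj = torch.randn(3 * 8, 16)
fake_img = (latents.reshape(2, -1) @ proj).reshape(2, 1, 4, 4)
penalty, new_mean, lengths = path_regularizer_sketch(
    fake_img, latents, torch.tensor(0.0))
```

The three returned values correspond to the penalty loss, the detached mean path tensor, and the current path lengths described above.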
