Source code for mmedit.models.losses.loss_comps.gan_loss_comps

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmedit.registry import MODELS

class GANLossComps(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge',
            'wgan-logistic-ns'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always
            1.0 for discriminators.

    Raises:
        NotImplementedError: If ``gan_type`` is not one of the supported
            types listed above.
    """

    def __init__(self,
                 gan_type: str,
                 real_label_val: float = 1.0,
                 fake_label_val: float = 0.0,
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        # ``self.loss`` is either an ``nn.Module`` criterion taking
        # ``(input, target_tensor)``, a bound method taking
        # ``(input, target_bool)``, or (hinge) a plain ReLU that ``forward``
        # applies to the margin itself.
        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'wgan-logistic-ns':
            self.loss = self._wgan_logistic_ns_loss
        elif self.gan_type == 'hinge':
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(
                f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input: torch.Tensor, target: bool) -> torch.Tensor:
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss; the negated mean of ``input`` when ``target``
            is truthy, otherwise the mean of ``input``.
        """
        return -input.mean() if target else input.mean()

    def _wgan_logistic_ns_loss(self, input: torch.Tensor,
                               target: bool) -> torch.Tensor:
        """WGAN loss in logistically non-saturating mode.

        This loss is widely used in StyleGANv2.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss; ``mean(softplus(-input))`` when ``target`` is
            truthy, otherwise ``mean(softplus(input))``.
        """
        return F.softplus(-input).mean() if target else F.softplus(
            input).mean()

    def get_target_label(self, input: torch.Tensor,
                         target_is_real: bool) -> Union[bool, torch.Tensor]:
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
            return a Tensor shaped like ``input`` and filled with the
            configured real/fake label value.
        """
        # The wgan-style losses only need to know the direction, not a
        # per-element target.
        if self.gan_type in ('wgan', 'wgan-logistic-ns'):
            return target_is_real
        target_val = (
            self.real_label_val if target_is_real else self.fake_label_val)
        # Build the filled tensor in one allocation, on the same device and
        # with the same dtype as ``input``.
        return input.new_full(input.size(), target_val)

    def forward(self,
                input: torch.Tensor,
                target_is_real: bool,
                is_disc: bool = False) -> torch.Tensor:
        """Compute the GAN loss for a prediction.

        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        if self.gan_type == 'hinge':
            # Hinge never uses a target label, so none is built here.
            if is_disc:  # for discriminators in hinge-gan
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                # NOTE: ``target_is_real`` has no effect on this branch.
                loss = -input.mean()
        else:  # other gan types
            target_label = self.get_target_label(input, target_is_real)
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.