mmedit.models.losses.perceptual_loss

Module Contents

Classes

PerceptualVGG

VGG network used in calculating perceptual loss.

PerceptualLoss

Perceptual loss with commonly used style loss.

TransferalPerceptualLoss

Transferal perceptual loss.

class mmedit.models.losses.perceptual_loss.PerceptualVGG(layer_name_list: List[str], vgg_type: str = 'vgg19', use_input_norm: bool = True, pretrained: str = 'torchvision://vgg19')[source]

Bases: torch.nn.Module

VGG network used in calculating perceptual loss.

This implementation lets users choose whether to normalize the input and which type of VGG network to use. Note that the pretrained path must match the VGG type.

Parameters
  • layer_name_list (list[str]) – The forward function returns the features of the layers named in this list. Each entry is the name of a layer in vgg.feature. An example of this list is ['4', '10'].

  • vgg_type (str) – The type of VGG network. Default: 'vgg19'.

  • use_input_norm (bool) – If True, normalize the input image. Importantly, the input must be in the range [0, 1]. Default: True.

  • pretrained (str) – Path for pretrained weights. Default: 'torchvision://vgg19'.
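The way layer_name_list selects intermediate features can be illustrated with a minimal sketch. It does not use the mmedit class: LayerFeatureExtractor and the toy layer stack below are hypothetical stand-ins, so no pretrained weights are downloaded. Returning a dict keyed by layer name and normalizing with the ImageNet channel statistics are assumptions made for the illustration.

```python
import torch
import torch.nn as nn

# ImageNet channel statistics, commonly used to normalize VGG inputs.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class LayerFeatureExtractor(nn.Module):
    """Hypothetical stand-in for a VGG feature extractor (not the mmedit class)."""

    def __init__(self, layers, layer_name_list, use_input_norm=True):
        super().__init__()
        self.layers = layers
        self.layer_name_list = set(layer_name_list)
        self.use_input_norm = use_input_norm
        # Buffers move with the module across devices but are not trained.
        self.register_buffer("mean", torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(IMAGENET_STD).view(1, 3, 1, 1))

    def forward(self, x):
        if self.use_input_norm:
            # Assumes x is already in the range [0, 1].
            x = (x - self.mean) / self.std
        features = {}
        # Children of nn.Sequential are named '0', '1', '2', ...
        for name, layer in self.layers.named_children():
            x = layer(x)
            if name in self.layer_name_list:
                features[name] = x
        return features


# Toy layer stack standing in for the VGG feature layers.
toy_layers = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),   # '0'
    nn.ReLU(),                       # '1'
    nn.MaxPool2d(2),                 # '2'
    nn.Conv2d(8, 16, 3, padding=1),  # '3'
    nn.ReLU(),                       # '4'
)
extractor = LayerFeatureExtractor(toy_layers, ["1", "4"])
feats = extractor(torch.rand(2, 3, 16, 16))
```

Here feats holds one tensor per requested layer name; the names are string indices because that is how nn.Sequential names its children.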

forward(x: torch.Tensor) → torch.Tensor[source]

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor

init_weights(model: torch.nn.Module, pretrained: str) → None[source]

Init weights.

Parameters
  • model (nn.Module) – Model to be initialized.

  • pretrained (str) – Path for pretrained weights.

class mmedit.models.losses.perceptual_loss.PerceptualLoss(layer_weights: dict, layer_weights_style: Optional[dict] = None, vgg_type: str = 'vgg19', use_input_norm: bool = True, perceptual_weight: float = 1.0, style_weight: float = 1.0, norm_img: bool = True, pretrained: str = 'torchvision://vgg19', criterion: str = 'l1')[source]

Bases: torch.nn.Module

Perceptual loss with commonly used style loss.

Parameters
  • layer_weights (dict) – The weight for each layer of VGG features for perceptual loss. For example, {'4': 1., '9': 1., '18': 1.} means that the 5th, 10th and 19th feature layers will be extracted, each with weight 1.0, when calculating losses.

  • layer_weights_style (dict) – The weight for each layer of VGG features for style loss. If set to None, the weights for perceptual loss are used. Default: None.

  • vgg_type (str) – The type of VGG network used as feature extractor. Default: 'vgg19'.

  • use_input_norm (bool) – If True, normalize the input image in VGG. Default: True.

  • perceptual_weight (float) – If perceptual_weight > 0, the perceptual loss will be calculated and multiplied by this weight. Default: 1.0.

  • style_weight (float) – If style_weight > 0, the style loss will be calculated and multiplied by this weight. Default: 1.0.

  • norm_img (bool) – If True, the image will be rescaled to [0, 1]. Note that this differs from use_input_norm, which normalizes the input in the forward function of VGG according to the dataset statistics. Importantly, the input image must be in the range [-1, 1]. Default: True.

  • pretrained (str) – Path for pretrained weights. Default: 'torchvision://vgg19'.

  • criterion (str) – Criterion type. Options are 'l1' and 'mse'. Default: 'l1'.
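As a rough illustration of how the layer weights, perceptual_weight and norm_img described above could combine, here is a minimal sketch. It is not the mmedit implementation: weighted_feature_loss and to_unit_range are hypothetical helpers, the feature dicts are toy tensors rather than VGG outputs, and the mapping (img + 1) / 2 is an assumed reading of the [-1, 1] to [0, 1] rescaling.

```python
import torch
import torch.nn.functional as F


def weighted_feature_loss(x_feats, gt_feats, layer_weights, criterion="l1"):
    """Sum the per-layer distances, each scaled by its layer weight."""
    distances = {"l1": F.l1_loss, "mse": F.mse_loss}
    dist = distances[criterion]
    total = torch.zeros(())
    for name, weight in layer_weights.items():
        total = total + weight * dist(x_feats[name], gt_feats[name])
    return total


def to_unit_range(img):
    """Map an image from [-1, 1] to [0, 1] (assumed role of norm_img)."""
    return (img + 1.0) * 0.5


# Toy features keyed by layer name, standing in for VGG outputs.
layer_weights = {"4": 1.0, "9": 0.5}
x_feats = {"4": torch.ones(1, 2, 4, 4), "9": torch.zeros(1, 2, 2, 2)}
gt_feats = {"4": torch.zeros(1, 2, 4, 4), "9": torch.full((1, 2, 2, 2), 2.0)}

perceptual_weight = 1.0
loss = perceptual_weight * weighted_feature_loss(x_feats, gt_feats, layer_weights)
rescaled = to_unit_range(torch.tensor([-1.0, 0.0, 1.0]))
```

With these toy values, layer '4' contributes 1.0 * 1.0 and layer '9' contributes 0.5 * 2.0, so the weighted sum is 2.0.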

forward(x: torch.Tensor, gt: torch.Tensor) → Tuple[torch.Tensor][source]

Forward function.

Parameters
  • x (Tensor) – Input tensor with shape (n, c, h, w).

  • gt (Tensor) – Ground-truth tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tuple[Tensor]

_gram_mat(x: torch.Tensor) → torch.Tensor[source]

Calculate Gram matrix.

Parameters

x (torch.Tensor) – Tensor with shape of (n, c, h, w).

Returns

Gram matrix.

Return type

torch.Tensor
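For readers unfamiliar with the Gram matrix used in style losses, a minimal sketch follows. The function name gram_matrix is hypothetical, and dividing by c * h * w is one common normalization that is assumed here rather than taken from the mmedit source.

```python
import torch


def gram_matrix(x):
    """Return per-sample channel correlations with shape (n, c, c)."""
    n, c, h, w = x.shape
    # Flatten the spatial dimensions: (n, c, h * w).
    feats = x.reshape(n, c, h * w)
    # Batched product of the features with their own transpose.
    gram = torch.bmm(feats, feats.transpose(1, 2))
    # Assumed normalization by the number of elements per sample.
    return gram / (c * h * w)


x = torch.ones(2, 3, 4, 5)
g = gram_matrix(x)
```

For an all-ones input, every entry of the unnormalized product equals h * w, so after normalization each entry is 1 / c.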

class mmedit.models.losses.perceptual_loss.TransferalPerceptualLoss(loss_weight: float = 1.0, use_attention: bool = True, criterion: str = 'mse')[source]

Bases: torch.nn.Module

Transferal perceptual loss.

Parameters
  • loss_weight (float) – Loss weight. Default: 1.0.

  • use_attention (bool) – If True, use soft-attention tensor. Default: True.

  • criterion (str) – Criterion type. Options are 'l1' and 'mse'. Default: 'mse'.

forward(maps: Tuple[torch.Tensor], soft_attention: torch.Tensor, textures: Tuple[torch.Tensor]) → torch.Tensor[source]

Forward function.

Parameters
  • maps (Tuple[Tensor]) – Input tensors.

  • soft_attention (Tensor) – Soft-attention tensor.

  • textures (Tuple[Tensor]) – Ground-truth tensors.

Returns

Forward results.

Return type

Tensor
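One plausible way the three inputs interact is sketched below, under stated assumptions. attention_weighted_loss is a hypothetical function, not the mmedit implementation: it assumes maps and textures are paired level by level, and that the soft-attention tensor is resized bilinearly to each level's spatial size before weighting both sides of the comparison.

```python
import torch
import torch.nn.functional as F


def attention_weighted_loss(maps, soft_attention, textures,
                            loss_weight=1.0, use_attention=True,
                            criterion="mse"):
    """Sum per-level distances between attention-weighted maps and textures."""
    distances = {"l1": F.l1_loss, "mse": F.mse_loss}
    dist = distances[criterion]
    total = torch.zeros(())
    for feat, texture in zip(maps, textures):
        if use_attention:
            # Assumption: resize the attention to this level's (h, w).
            soft = F.interpolate(soft_attention, size=tuple(feat.shape[-2:]),
                                 mode="bilinear", align_corners=False)
            total = total + dist(feat * soft, texture * soft)
        else:
            total = total + dist(feat, texture)
    return loss_weight * total


# Two toy feature levels with a single-channel attention map.
maps = (torch.ones(1, 4, 8, 8), torch.ones(1, 4, 16, 16))
textures = (torch.zeros(1, 4, 8, 8), torch.zeros(1, 4, 16, 16))
soft_attention = torch.full((1, 1, 8, 8), 0.5)

with_attention = attention_weighted_loss(maps, soft_attention, textures)
without_attention = attention_weighted_loss(maps, soft_attention, textures,
                                            use_attention=False)
```

Because the attention multiplies both the map and the texture, regions with low attention contribute less to the loss; with use_attention=False every location counts equally.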
