
mmedit.models.losses.loss_comps.gen_auxiliary_loss_comps

Module Contents

Classes

GeneratorPathRegularizerComps

Generator Path Regularizer.

class mmedit.models.losses.loss_comps.gen_auxiliary_loss_comps.GeneratorPathRegularizerComps(loss_weight: float = 1.0, pl_batch_shrink: int = 1, decay: float = 0.01, pl_batch_size: Optional[int] = None, sync_mean_buffer: bool = False, interval: int = 1, data_info: Optional[dict] = None, use_apex_amp: bool = False, loss_name: str = 'loss_path_regular') [source]

Bases: torch.nn.Module

Generator Path Regularizer.

Path regularization was proposed in StyleGAN2 and helps improve the continuity of the latent space. More details can be found in: Analyzing and Improving the Image Quality of StyleGAN, CVPR 2020.
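For reference, the path length regularizer from the StyleGAN2 paper can be written as below. Here g maps an intermediate latent w to an image, J_w is the Jacobian of g with respect to w, y is a random image with normally distributed pixel values, and a is a running mean of the path lengths (the overline denotes the batch mean). Reading lambda as loss_weight and beta as decay is an assumption based on the parameter descriptions, not a statement of this module's exact code.

```latex
\[
\mathcal{L}_{\mathrm{pl}} = \lambda \, \mathbb{E}_{w,\; y \sim \mathcal{N}(0, I)}
  \left[ \left( \lVert \mathbf{J}_w^{\top} y \rVert_2 - a \right)^2 \right],
\qquad
\mathbf{J}_w^{\top} y = \nabla_w \left( g(w) \cdot y \right)
\]
\[
a \leftarrow a + \beta \left( \overline{\lVert \mathbf{J}_w^{\top} y \rVert_2} - a \right)
\]
```

The identity on the right of the first line is why the vector-Jacobian product can be obtained from a single backward pass of the scalar g(w) · y instead of forming the full Jacobian.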

Lazy regularization can be enabled by setting the interval argument.
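A minimal sketch of how an interval can gate the regularizer, assuming a plain modulo test on a 0-based iteration counter. The helper name should_regularize is hypothetical, and the real module may use a different offset.

```python
def should_regularize(iteration: int, interval: int) -> bool:
    """Return True if the path regularizer runs at this 0-based iteration.

    Hypothetical helper. Assumption: a plain modulo test; the real module
    may count iterations differently (e.g. iteration + 1).
    """
    if interval < 1:
        raise ValueError("interval must be a positive integer")
    return iteration % interval == 0


# With interval=4, the regularizer runs on 4 of 16 iterations.
active = [it for it in range(16) if should_regularize(it, interval=4)]
print(active)  # [0, 4, 8, 12]
```

With interval=1 (the default) the check always passes, so the loss is computed at every iteration.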

Note on the design of ``data_info``: In MMEditing, almost all loss modules accept a data_info argument, which links the input items needed for the loss calculation to the data produced by the generative model. For example, when training a GAN model, all the important data and modules are collected into a dictionary:

Code from StaticUnconditionalGAN, train_step
data_dict_ = dict(
    gen=self.generator,
    disc=self.discriminator,
    fake_imgs=fake_imgs,
    disc_pred_fake_g=disc_pred_fake_g,
    iteration=curr_iter,
    batch_size=batch_size)

This loss, however, needs generator and num_batches as inputs. An example data_info is therefore:

data_info = dict(
    generator='gen',
    num_batches='batch_size')

The module then automatically builds the loss inputs from the input data dictionary according to this mapping.
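The mapping step can be pictured as a dictionary lookup: each key of data_info names a loss argument, and each value names the entry of the data dictionary that supplies it. The helper resolve_loss_inputs below is hypothetical and only illustrates this idea; the string placeholders stand in for the real generator module.

```python
from typing import Any, Dict


def resolve_loss_inputs(data_info: Dict[str, str],
                        outputs_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Map each loss argument name to the value stored under its key.

    Hypothetical helper illustrating the documented mapping; it is not
    the module's actual code.
    """
    return {arg_name: outputs_dict[key] for arg_name, key in data_info.items()}


# Placeholder values stand in for the real modules and tensors.
outputs_dict = dict(gen='<generator module>', disc='<discriminator module>',
                    iteration=100, batch_size=16)
data_info = dict(generator='gen', num_batches='batch_size')
print(resolve_loss_inputs(data_info, outputs_dict))
# {'generator': '<generator module>', 'num_batches': 16}
```

In this sketch a key missing from the data dictionary raises KeyError, which surfaces a misconfigured data_info early.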

Parameters
  • loss_weight (float, optional) – Weight of this loss item. Defaults to 1.0.

  • pl_batch_shrink (int, optional) – Factor by which the batch size is shrunk to save GPU memory. Defaults to 1.

  • decay (float, optional) – Decay for moving average of mean path length. Defaults to 0.01.

  • pl_batch_size (int | None, optional) – The batch size used when calculating the generator path length. If this argument is set, it overrides num_batches and is not affected by pl_batch_shrink. Defaults to None.

  • sync_mean_buffer (bool, optional) – Whether to sync the mean path length across all GPUs. Defaults to False.

  • interval (int, optional) – The interval, in iterations, at which this loss is calculated. This argument supports lazy regularization. Defaults to 1.

  • data_info (dict, optional) – Dictionary containing the mapping between the loss function's input arguments and the keys of the data dictionary. If None, this module passes the input data directly to the loss function. Defaults to None.

  • loss_name (str, optional) – Name of the loss item. To include this loss item in the backward graph, its name must start with loss_. Defaults to ‘loss_path_regular’.
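The interplay of num_batches, pl_batch_shrink, pl_batch_size, and decay can be sketched as two small functions. Both helper names are hypothetical, and the exact rules (integer division floored at one sample, decay used as an exponential-moving-average step) are assumptions drawn from the parameter descriptions rather than from the module's source.

```python
from typing import Optional


def effective_pl_batch_size(num_batches: int,
                            pl_batch_shrink: int = 1,
                            pl_batch_size: Optional[int] = None) -> int:
    """Batch size used for the path-length pass (hypothetical sketch).

    Assumption: an explicit pl_batch_size takes precedence; otherwise the
    batch is shrunk by integer division, but never below one sample.
    """
    if pl_batch_size is not None:
        return pl_batch_size
    return max(1, num_batches // pl_batch_shrink)


def update_mean_path_length(mean_path_length: float,
                            batch_mean: float,
                            decay: float = 0.01) -> float:
    """Exponential moving average of the mean path length (sketch).

    Assumption: decay is the step size that pulls the running mean toward
    the mean path length of the current batch.
    """
    return mean_path_length + decay * (batch_mean - mean_path_length)


print(effective_pl_batch_size(16, pl_batch_shrink=2))                   # 8
print(effective_pl_batch_size(16, pl_batch_shrink=2, pl_batch_size=4))  # 4
print(effective_pl_batch_size(1, pl_batch_shrink=4))                    # 1
print(update_mean_path_length(0.0, 2.0, decay=0.01))                    # 0.02
```

With the default decay of 0.01, a single noisy batch moves the running mean only slightly, which keeps the regularization target stable.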

forward(*args, **kwargs) → torch.Tensor [source]

Forward function.

If self.data_info is not None, a dictionary containing all of the data and necessary modules must be passed to this function, either as the first positional argument or as the keyword argument outputs_dict.

If self.data_info is None, the positional and keyword arguments are passed directly to the loss function, gen_path_regularizer.
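The two calling conventions described above can be sketched with a plain function. dispatch_loss and toy_loss are hypothetical names; toy_loss stands in for gen_path_regularizer, and the sketch omits the extra state (such as the running mean path length) that the real module supplies to the loss function.

```python
from typing import Any, Callable, Dict, Optional


def dispatch_loss(loss_fn: Callable[..., Any],
                  data_info: Optional[Dict[str, str]],
                  *args: Any, **kwargs: Any) -> Any:
    """Mimic the documented forward() argument handling (hypothetical)."""
    if data_info is None:
        # No mapping: forward all arguments to the loss function untouched.
        return loss_fn(*args, **kwargs)
    # With a mapping, the data dictionary is either the first positional
    # argument or the keyword argument named outputs_dict.
    outputs_dict = args[0] if args else kwargs['outputs_dict']
    loss_kwargs = {name: outputs_dict[key] for name, key in data_info.items()}
    return loss_fn(**loss_kwargs)


def toy_loss(generator: str, num_batches: int) -> str:
    # Stand-in for gen_path_regularizer; it just echoes its inputs.
    return f'{generator}:{num_batches}'


info = dict(generator='gen', num_batches='batch_size')
data = dict(gen='G', batch_size=8)
print(dispatch_loss(toy_loss, info, data))               # G:8
print(dispatch_loss(toy_loss, info, outputs_dict=data))  # G:8
print(dispatch_loss(toy_loss, None, 'G', 8))             # G:8
```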

loss_name() → str [source]

Loss Name.

This method must be implemented and returns the name of this loss item. The name is used when different loss items are combined by a simple sum. In addition, to include this loss item in the backward graph, its name must start with loss_.
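A minimal sketch of the prefix convention, assuming loss items are collected in a dictionary keyed by their names and only loss_-prefixed entries are summed for the backward pass. The helper total_backward_loss and the logged-only entry mean_path_length are hypothetical; the framework's own loss parsing may differ in detail.

```python
from typing import Dict


def total_backward_loss(losses: Dict[str, float]) -> float:
    """Sum the loss items whose names start with 'loss_' (sketch).

    Items without the prefix are treated as logged-only values.
    """
    return sum(value for name, value in losses.items()
               if name.startswith('loss_'))


losses = {'loss_gan': 0.5, 'loss_path_regular': 0.25, 'mean_path_length': 1.3}
print(total_backward_loss(losses))  # 0.75
```

Under this convention, renaming loss_path_regular to a name without the prefix would keep the value in the logs while dropping it from the optimized objective.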

Returns

The name of this loss item.

Return type

str
