
mmedit.models.losses.loss_comps.disc_auxiliary_loss_comps

Module Contents

Classes

DiscShiftLossComps

Disc Shift Loss.

GradientPenaltyLossComps

Gradient Penalty for WGAN-GP.

R1GradientPenaltyComps

R1 gradient penalty for the discriminator.

class mmedit.models.losses.loss_comps.disc_auxiliary_loss_comps.DiscShiftLossComps(loss_weight: float = 1.0, data_info: Optional[dict] = None, loss_name: str = 'loss_disc_shift')

Bases: torch.nn.Module

Disc Shift Loss.

This loss was proposed in PGGAN as an auxiliary loss for the discriminator.

Note on the design of data_info: in MMEditing, almost all loss modules accept a data_info argument, which links the inputs the loss needs to the data produced by the generative model. For example, when training a GAN model, all of the important data and modules are collected into a dictionary:

Code from StaticUnconditionalGAN, train_step
data_dict_ = dict(
    gen=self.generator,
    disc=self.discriminator,
    disc_pred_fake=disc_pred_fake,
    disc_pred_real=disc_pred_real,
    fake_imgs=fake_imgs,
    real_imgs=real_imgs,
    iteration=curr_iter,
    batch_size=batch_size)

This loss, however, takes pred as its input. An example data_info is therefore:

data_info = dict(
    pred='disc_pred_fake')

Then, the module will automatically construct this mapping from the input data dictionary.
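A minimal sketch of that mapping step, assuming it is a plain key lookup from each loss argument name to the corresponding entry of the training-step dictionary. The helper resolve_data_info is hypothetical and not part of MMEditing; the real module may do additional work.

```python
from typing import Any, Dict


def resolve_data_info(data_info: Dict[str, str],
                      outputs_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: map each loss argument to the value stored under its key."""
    return {arg: outputs_dict[key] for arg, key in data_info.items()}


# Stand-ins for a few of the entries that train_step collects.
outputs_dict = dict(disc_pred_fake=[0.3, -0.2], disc_pred_real=[1.1, 0.9],
                    iteration=10, batch_size=2)
data_info = dict(pred='disc_pred_fake')

kwargs = resolve_data_info(data_info, outputs_dict)
print(kwargs)  # {'pred': [0.3, -0.2]}
```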

In general, disc_shift_loss is applied to both real and fake data. In that case, add this loss module twice, each with a different data_info; the model automatically sums the two loss items.
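The sketch below illustrates applying the loss to real and fake predictions separately and summing the results. It assumes disc_shift_loss is the mean of the squared discriminator predictions (the drift penalty from PGGAN); disc_shift_loss_sketch is a hypothetical name, not the library function.

```python
import torch


def disc_shift_loss_sketch(pred: torch.Tensor, loss_weight: float = 1.0) -> torch.Tensor:
    # Assumed form: mean of squared discriminator outputs, scaled by loss_weight.
    return loss_weight * pred.pow(2).mean()


disc_pred_real = torch.tensor([[1.0], [-1.0]])
disc_pred_fake = torch.tensor([[2.0], [0.0]])

# One loss item per data source, as with two modules using different data_info.
loss_real = disc_shift_loss_sketch(disc_pred_real)
loss_fake = disc_shift_loss_sketch(disc_pred_fake)
total = loss_real + loss_fake
print(loss_real.item(), loss_fake.item(), total.item())  # 1.0 2.0 3.0
```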

Parameters
  • loss_weight (float, optional) – Weight of this loss item. Defaults to 1.0.

  • data_info (dict, optional) – Dictionary containing the mapping between the loss input arguments and the data dictionary. If None, this module passes the input data directly to the loss function. Defaults to None.

  • loss_name (str, optional) – Name of the loss item. To include this loss item in the backward graph, the name must start with loss_. Defaults to ‘loss_disc_shift’.

forward(*args, **kwargs) → torch.Tensor

Forward function.

If self.data_info is not None, a dictionary containing all of the data and necessary modules must be passed to this function. If the dictionary is passed positionally, it must be the first argument; if it is passed as a keyword argument, it must be named outputs_dict.

If self.data_info is None, the positional and keyword arguments are passed directly to the loss function, disc_shift_loss.
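The dispatch described above can be sketched as follows, assuming a plain dictionary lookup when data_info is set. LossCompSketch and mean_square are hypothetical stand-ins for illustration, not MMEditing classes or functions.

```python
from typing import Any, Callable, Dict, Optional


class LossCompSketch:
    """Hypothetical stand-in for the forward dispatch of a loss module."""

    def __init__(self, loss_fn: Callable[..., Any],
                 data_info: Optional[Dict[str, str]] = None):
        self.loss_fn = loss_fn
        self.data_info = data_info

    def forward(self, *args, **kwargs):
        if self.data_info is None:
            # No mapping: hand every argument to the loss function unchanged.
            return self.loss_fn(*args, **kwargs)
        # With a mapping, the data dict is the first positional argument
        # or the keyword argument named outputs_dict.
        outputs_dict = args[0] if args else kwargs['outputs_dict']
        loss_kwargs = {arg: outputs_dict[key]
                       for arg, key in self.data_info.items()}
        return self.loss_fn(**loss_kwargs)


def mean_square(pred):
    return sum(p * p for p in pred) / len(pred)


mapped = LossCompSketch(mean_square, data_info=dict(pred='disc_pred_fake'))
direct = LossCompSketch(mean_square)
outputs = dict(disc_pred_fake=[1.0, 3.0])
print(mapped.forward(outputs))               # 5.0
print(mapped.forward(outputs_dict=outputs))  # 5.0
print(direct.forward(pred=[1.0, 3.0]))       # 5.0
```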

loss_name() → str

Loss Name.

This method must be implemented and returns the name of this loss item. The name is used when different loss items are combined by a simple sum. To include this loss item in the backward graph, its name must start with loss_.
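A minimal sketch of the naming convention, assuming the combined loss is the sum of every item whose name starts with loss_ and that other items are only logged. parse_losses_sketch is a hypothetical helper, not the function MMEditing uses internally.

```python
from typing import Dict


def parse_losses_sketch(losses: Dict[str, float]) -> float:
    """Hypothetical helper: sum every item whose name starts with 'loss_'."""
    return sum(value for name, value in losses.items() if name.startswith('loss_'))


losses = {
    'loss_disc_shift': 0.25,  # included in the total
    'loss_gp': 1.5,           # included in the total
    'r1_gp_monitor': 10.0,    # no 'loss_' prefix: not summed
}
print(parse_losses_sketch(losses))  # 1.75
```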

Returns

The name of this loss item.

Return type

str

class mmedit.models.losses.loss_comps.disc_auxiliary_loss_comps.GradientPenaltyLossComps(loss_weight: float = 1.0, norm_mode: str = 'pixel', data_info: Optional[dict] = None, loss_name: str = 'loss_gp')

Bases: torch.nn.Module

Gradient Penalty for WGAN-GP.

Implementations of this penalty differ in how the gradient norm is computed: one computes the norm pixel-wise, while the other computes it over the whole instance (the H, W and C dimensions). The norm_mode argument selects which mode to use.
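A minimal sketch of the WGAN-GP penalty with both norm modes, assuming the usual recipe of random interpolation between real and fake samples. It assumes "pixel" means the L2 norm over channels at each pixel and "HWC" means the L2 norm over each flattened sample. gradient_penalty_sketch and the toy SumDisc discriminator are hypothetical, not MMEditing code.

```python
import torch
from torch import nn


def gradient_penalty_sketch(discriminator: nn.Module, real_data: torch.Tensor,
                            fake_data: torch.Tensor, norm_mode: str = 'pixel') -> torch.Tensor:
    batch_size = real_data.size(0)
    # Random points on straight lines between paired real and fake samples.
    alpha = torch.rand(batch_size, 1, 1, 1, device=real_data.device)
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).detach()
    interpolates.requires_grad_(True)
    disc_out = discriminator(interpolates)
    gradients = torch.autograd.grad(
        outputs=disc_out, inputs=interpolates,
        grad_outputs=torch.ones_like(disc_out), create_graph=True)[0]
    if norm_mode == 'pixel':
        # L2 norm over channels at every pixel: shape (N, H, W).
        grad_norm = gradients.norm(2, dim=1)
    elif norm_mode == 'HWC':
        # L2 norm over the whole flattened instance: shape (N,).
        grad_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
    else:
        raise ValueError(f'unsupported norm_mode: {norm_mode}')
    # Penalize deviation of the gradient norm from 1.
    return ((grad_norm - 1) ** 2).mean()


class SumDisc(nn.Module):
    """Toy discriminator whose input gradient is all ones."""

    def forward(self, x):
        return x.sum(dim=(1, 2, 3))


real = torch.randn(4, 3, 2, 2)
fake = torch.randn(4, 3, 2, 2)
print(gradient_penalty_sketch(SumDisc(), real, fake, 'pixel').item())  # about 0.536
print(gradient_penalty_sketch(SumDisc(), real, fake, 'HWC').item())    # about 6.072
```

With the toy discriminator the gradient is 1 everywhere, so the pixel-wise norm is the square root of 3 (three channels) while the instance norm is the square root of 12 (all 3 × 2 × 2 entries), which is why the two modes give different penalties for the same input.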

Note on the design of data_info: in MMEditing, almost all loss modules accept a data_info argument, which links the inputs the loss needs to the data produced by the generative model. For example, when training a GAN model, all of the important data and modules are collected into a dictionary:

Code from StaticUnconditionalGAN, train_step
data_dict_ = dict(
    gen=self.generator,
    disc=self.discriminator,
    disc_pred_fake=disc_pred_fake,
    disc_pred_real=disc_pred_real,
    fake_imgs=fake_imgs,
    real_imgs=real_imgs,
    iteration=curr_iter,
    batch_size=batch_size)

This loss, however, takes discriminator, real_data and fake_data as its inputs. An example data_info is therefore:

data_info = dict(
    discriminator='disc',
    real_data='real_imgs',
    fake_data='fake_imgs')

Then, the module will automatically construct this mapping from the input data dictionary.

Parameters
  • loss_weight (float, optional) – Weight of this loss item. Defaults to 1.0.

  • data_info (dict, optional) – Dictionary containing the mapping between the loss input arguments and the data dictionary. If None, this module passes the input data directly to the loss function. Defaults to None.

  • norm_mode (str) – The dimensions along which the gradient norm is computed. Supported values are “pixel” and “HWC”. Defaults to “pixel”.

  • loss_name (str, optional) – Name of the loss item. To include this loss item in the backward graph, the name must start with loss_. Defaults to ‘loss_gp’.

forward(*args, **kwargs) → torch.Tensor

Forward function.

If self.data_info is not None, a dictionary containing all of the data and necessary modules must be passed to this function. If the dictionary is passed positionally, it must be the first argument; if it is passed as a keyword argument, it must be named outputs_dict.

If self.data_info is None, the positional and keyword arguments are passed directly to the loss function, gradient_penalty_loss.

loss_name() → str

Loss Name.

This method must be implemented and returns the name of this loss item. The name is used when different loss items are combined by a simple sum. To include this loss item in the backward graph, its name must start with loss_.

Returns

The name of this loss item.

Return type

str

class mmedit.models.losses.loss_comps.disc_auxiliary_loss_comps.R1GradientPenaltyComps(loss_weight: float = 1.0, norm_mode: str = 'pixel', interval: int = 1, data_info: Optional[dict] = None, use_apex_amp: bool = False, loss_name: str = 'loss_r1_gp')

Bases: torch.nn.Module

R1 gradient penalty for the discriminator.

The R1 regularizer comes from “Which Training Methods for GANs do actually Converge?” (ICML 2018).

Unlike the original gradient penalty, this regularizer only penalizes the gradient with respect to real data.
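A minimal sketch of the R1 penalty: the squared L2 norm of the discriminator's gradient, taken only at real samples and averaged over the batch. The γ/2 factor from the paper is left to loss_weight in this sketch, and norm_mode is not modeled. r1_penalty_sketch and the toy SumDisc discriminator are hypothetical, not MMEditing code.

```python
import torch
from torch import nn


def r1_penalty_sketch(discriminator: nn.Module, real_data: torch.Tensor) -> torch.Tensor:
    """Mean over the batch of ||grad_x D(x)||^2, evaluated on real data only."""
    real_data = real_data.detach().clone().requires_grad_(True)
    disc_pred = discriminator(real_data)
    gradients = torch.autograd.grad(
        outputs=disc_pred, inputs=real_data,
        grad_outputs=torch.ones_like(disc_pred), create_graph=True)[0]
    # Squared L2 norm per sample, then averaged over the batch.
    return gradients.pow(2).reshape(real_data.size(0), -1).sum(dim=1).mean()


class SumDisc(nn.Module):
    """Toy discriminator whose input gradient is 2 everywhere."""

    def forward(self, x):
        return 2.0 * x.sum(dim=(1, 2, 3))


real_imgs = torch.randn(4, 3, 2, 2)
print(r1_penalty_sketch(SumDisc(), real_imgs).item())  # 48.0 (12 entries x 2^2)
```

Unlike the WGAN-GP sketch, no fake samples or interpolation are involved, and the penalty pushes the gradient norm toward 0 rather than 1.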

Note on the design of data_info: in MMEditing, almost all loss modules accept a data_info argument, which links the inputs the loss needs to the data produced by the generative model. For example, when training a GAN model, all of the important data and modules are collected into a dictionary:

Code from StaticUnconditionalGAN, train_step
data_dict_ = dict(
    gen=self.generator,
    disc=self.discriminator,
    disc_pred_fake=disc_pred_fake,
    disc_pred_real=disc_pred_real,
    fake_imgs=fake_imgs,
    real_imgs=real_imgs,
    iteration=curr_iter,
    batch_size=batch_size)

This loss, however, takes discriminator and real_data as its inputs. An example data_info is therefore:

data_info = dict(
    discriminator='disc',
    real_data='real_imgs')

Then, the module will automatically construct this mapping from the input data dictionary.

Parameters
  • loss_weight (float, optional) – Weight of this loss item. Defaults to 1.0.

  • data_info (dict, optional) – Dictionary containing the mapping between the loss input arguments and the data dictionary. If None, this module passes the input data directly to the loss function. Defaults to None.

  • norm_mode (str) – The dimensions along which the gradient norm is computed. Supported values are “pixel” and “HWC”. Defaults to “pixel”.

  • interval (int, optional) – The interval, in training iterations, at which this loss is computed. Defaults to 1.

  • loss_name (str, optional) – Name of the loss item. To include this loss item in the backward graph, the name must start with loss_. Defaults to ‘loss_r1_gp’.
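A minimal sketch of how an interval larger than 1 could gate the penalty, assuming the current step is read from the iteration entry of the training-step dictionary and that no loss item is produced on skipped steps. r1_every_interval is a hypothetical helper, not part of MMEditing.

```python
from typing import Any, Callable, Dict, Optional


def r1_every_interval(outputs_dict: Dict[str, Any], interval: int,
                      compute: Callable[[], float]) -> Optional[float]:
    # Hypothetical helper: skip the penalty unless the current iteration
    # is a multiple of interval.
    if interval > 1 and outputs_dict['iteration'] % interval != 0:
        return None  # no loss item on skipped iterations
    return compute()


computed_at = []
for curr_iter in range(8):
    out = r1_every_interval(dict(iteration=curr_iter), interval=4,
                            compute=lambda: 0.5)
    if out is not None:
        computed_at.append(curr_iter)
print(computed_at)  # [0, 4]
```

Skipping the penalty on most iterations reduces its cost; StyleGAN2 uses the same idea under the name lazy regularization.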

forward(*args, **kwargs) → torch.Tensor

Forward function.

If self.data_info is not None, a dictionary containing all of the data and necessary modules must be passed to this function. If the dictionary is passed positionally, it must be the first argument; if it is passed as a keyword argument, it must be named outputs_dict.

If self.data_info is None, the positional and keyword arguments are passed directly to the loss function, r1_gradient_penalty_loss.

loss_name() → str

Loss Name.

This method must be implemented and returns the name of this loss item. The name is used when different loss items are combined by a simple sum. To include this loss item in the backward graph, its name must start with loss_.

Returns

The name of this loss item.

Return type

str
