
mmedit.models.editors.deepfillv1.deepfillv1

Module Contents

Classes

DeepFillv1Inpaintor

Inpaintor for the DeepFillv1 method.

class mmedit.models.editors.deepfillv1.deepfillv1.DeepFillv1Inpaintor(data_preprocessor: dict, encdec: dict, disc=None, loss_gan=None, loss_gp=None, loss_disc_shift=None, loss_composed_percep=None, loss_out_percep=False, loss_l1_hole=None, loss_l1_valid=None, loss_tv=None, stage1_loss_type=None, stage2_loss_type=None, train_cfg=None, test_cfg=None, init_cfg: Optional[dict] = None)[source]

Bases: mmedit.models.base_models.TwoStageInpaintor

Inpaintor for the DeepFillv1 method.

This inpaintor implements the method from the paper: Generative Image Inpainting with Contextual Attention.

Notably, this inpaintor is an example of using a custom training schedule based on TwoStageInpaintor.

The training pipeline of DeepFillv1 is as follows:

if cur_iter < iter_tc:
    update generator with only l1 loss
else:
    update discriminator
    if cur_iter > iter_td:
        update generator with l1 loss and adversarial loss

A new attribute, cur_iter, records the current iteration number. The train_cfg contains the settings of the training schedule:

train_cfg = dict(
    start_iter=0,
    disc_step=1,
    iter_tc=90000,
    iter_td=100000
)

iter_tc and iter_td correspond to the notation \(T_C\) and \(T_D\) of the original paper.
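The schedule above can be sketched as a small pure-Python function that, given the current iteration and the train_cfg dict, lists which updates happen. This is a minimal sketch, not mmedit code: the helper plan_updates and the update labels are hypothetical, and the comparisons follow the pseudocode literally, so between iter_tc and iter_td only the discriminator is updated.

```python
"""Sketch of the DeepFillv1 training schedule (T_C / T_D phases).

Assumption: `plan_updates` is a hypothetical helper, not part of mmedit;
it only mirrors the documented pseudocode.
"""
from typing import List


def plan_updates(cur_iter: int, train_cfg: dict) -> List[str]:
    """Return the updates performed at `cur_iter`, in order."""
    iter_tc = train_cfg['iter_tc']  # T_C in the paper
    iter_td = train_cfg['iter_td']  # T_D in the paper
    if cur_iter < iter_tc:
        # Phase 1: train the generator with the L1 reconstruction loss only.
        return ['generator:l1']
    # From T_C on, the discriminator is updated every iteration.
    updates = ['discriminator']
    if cur_iter > iter_td:
        # Phase 3: the generator also receives the adversarial loss.
        updates.append('generator:l1+adversarial')
    return updates


train_cfg = dict(start_iter=0, disc_step=1, iter_tc=90000, iter_td=100000)
for it in (0, 89999, 90000, 100000, 100001):
    print(it, plan_updates(it, train_cfg))
```

With the example train_cfg, iterations 0 to 89999 update only the generator, 90000 to 100000 only the discriminator, and from 100001 on both are updated.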

Parameters
  • encdec (dict) – Config for the encoder-decoder style generator.

  • disc (dict) – Config for discriminator.

  • loss_gan (dict) – Config for adversarial loss.

  • loss_gp (dict) – Config for gradient penalty loss.

  • loss_disc_shift (dict) – Config for discriminator shift loss.

  • loss_composed_percep (dict) – Config for perceptual and style loss with composed image as input.

  • loss_out_percep (dict) – Config for perceptual and style loss with direct output as input.

  • loss_l1_hole (dict) – Config for l1 loss in the hole.

  • loss_l1_valid (dict) – Config for l1 loss in the valid region.

  • loss_tv (dict) – Config for total variation loss.

  • train_cfg (dict) – Configs for the training scheduler. It must contain disc_step, which sets the number of discriminator updates performed per generator update.

  • test_cfg (dict) – Configs for testing scheduler.

  • init_cfg (dict, optional) – Initialization config dict.

forward_train_d(data_batch, is_real, is_disc)[source]

Forward function in discriminator training step.

This function overrides the default implementation, which assumes a single discriminator. DeepFillv1 uses two separate discriminators, one for global and one for local consistency.

Parameters
  • data_batch (torch.Tensor) – Batch of real data or fake data.

  • is_real (bool) – If True, the GAN loss treats this batch as real data; otherwise, it treats this batch as fake data.

  • is_disc (bool) – If True, this function is called in the discriminator training step; otherwise, it is called in the generator training step. This lets the function compute the appropriate form of adversarial loss (e.g. LSGAN) for each step.

Returns

Contains the loss items computed in this function.

Return type

dict
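The role of is_real and is_disc, and the two-discriminator layout, can be illustrated with a small pure-Python sketch. It is hypothetical, not mmedit code: plain lists of floats stand in for prediction tensors, the discriminators are passed in as a callable, the loss key names are illustrative, and the helper gan_loss implements the standard least-squares (LSGAN) and hinge losses. Hinge is included because its generator and discriminator forms differ, which is exactly the case is_disc distinguishes.

```python
"""Hypothetical sketch of a two-discriminator forward_train_d.

Assumptions (not mmedit API): predictions are plain lists of floats, the
global/local discriminators are one callable returning two prediction
lists, and the loss key names are illustrative only.
"""
from typing import Callable, Dict, List, Tuple

Pred = List[float]


def gan_loss(pred: Pred, is_real: bool, is_disc: bool,
             gan_type: str = 'lsgan') -> float:
    n = len(pred)
    if gan_type == 'lsgan':
        # Least-squares GAN: regress predictions toward 1 (real) or 0 (fake).
        target = 1.0 if is_real else 0.0
        return sum((p - target) ** 2 for p in pred) / n
    if gan_type == 'hinge':
        if is_disc:
            # Discriminator: mean relu(1 - D(x)) for real, relu(1 + D(x)) for fake.
            sign = -1.0 if is_real else 1.0
            return sum(max(0.0, 1.0 + sign * p) for p in pred) / n
        # Generator: -mean(D(G(z))).
        return -sum(pred) / n
    raise ValueError(f'unsupported gan_type: {gan_type}')


def forward_train_d(data_batch, is_real: bool, is_disc: bool,
                    disc: Callable[[object], Tuple[Pred, Pred]],
                    gan_type: str = 'lsgan') -> Dict[str, float]:
    # Two discriminators: one judges global consistency, one local consistency.
    global_pred, local_pred = disc(data_batch)
    loss_global = gan_loss(global_pred, is_real, is_disc, gan_type)
    loss_local = gan_loss(local_pred, is_real, is_disc, gan_type)
    prefix = 'real' if is_real else 'fake'
    return {f'{prefix}_loss_global': loss_global,
            f'{prefix}_loss_local': loss_local}


def toy_disc(batch) -> Tuple[Pred, Pred]:
    # Stand-in returning fixed (global, local) predictions.
    return [0.5, 1.5], [0.0, 1.0]


print(forward_train_d(None, is_real=True, is_disc=True, disc=toy_disc))
# {'real_loss_global': 0.25, 'real_loss_local': 0.5}
```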

two_stage_loss(stage1_data, stage2_data, gt, mask, masked_img)[source]

Calculate two-stage loss.

Parameters
  • stage1_data (dict) – Contains the stage1 results.

  • stage2_data (dict) – Contains the stage2 results.

  • gt (torch.Tensor) – Ground-truth image.

  • mask (torch.Tensor) – Mask image.

  • masked_img (torch.Tensor) – Composition of the mask and the ground-truth image.

Returns

A dict of the results computed in this function for visualization, and a dict of the loss items computed in this function.

Return type

tuple(dict)
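How the two returned dicts are assembled can be sketched as follows. This is a hypothetical illustration, not the mmedit implementation: it assumes each stage dict carries fake_res and fake_img entries, that stage1_loss_type and stage2_loss_type are tuples of loss names, and that loss_fn stands in for calculate_loss_with_type. Flat lists of floats replace tensors, and the result key names are illustrative.

```python
"""Hypothetical sketch of two_stage_loss.

Assumptions: stage dicts hold 'fake_res' and 'fake_img'; `loss_fn` stands in
for calculate_loss_with_type; lists of floats replace tensors; result key
names are illustrative only.
"""
from typing import Callable, Dict, List, Sequence, Tuple

Img = List[float]


def two_stage_loss(stage1_data: Dict[str, Img], stage2_data: Dict[str, Img],
                   gt: Img, mask: Img, masked_img: Img,
                   stage1_loss_type: Sequence[str],
                   stage2_loss_type: Sequence[str],
                   loss_fn: Callable[..., Dict[str, float]]
                   ) -> Tuple[dict, Dict[str, float]]:
    # Inputs and outputs kept for visualization.
    results = dict(gt_img=gt, mask=mask, masked_img=masked_img)
    losses: Dict[str, float] = {}
    stages = (('stage1_', stage1_data, stage1_loss_type),
              ('stage2_', stage2_data, stage2_loss_type))
    for prefix, data, loss_types in stages:
        results[prefix + 'fake_res'] = data['fake_res']
        results[prefix + 'fake_img'] = data['fake_img']
        for loss_type in loss_types:
            # Each call returns {prefix + name: value}, so the two stages
            # never overwrite each other's entries.
            losses.update(loss_fn(loss_type, data['fake_res'],
                                  data['fake_img'], gt, mask, prefix=prefix))
    return results, losses


def l1_stub(loss_type, fake_res, fake_img, gt, mask, prefix='stage1_'):
    # Plain mean absolute error; ignores the mask for brevity.
    err = sum(abs(f - g) for f, g in zip(fake_res, gt)) / len(gt)
    return {prefix + loss_type: err}


gt = [1.0, 1.0, 0.0, 0.0]
mask = [0.0, 1.0, 1.0, 0.0]        # 1 marks the hole
masked_img = [1.0, 0.0, 0.0, 0.0]  # gt with the hole zeroed
s1 = dict(fake_res=[1.0, 0.5, 0.5, 0.0], fake_img=[1.0, 0.5, 0.5, 0.0])
s2 = dict(fake_res=[1.0, 1.0, 0.0, 0.0], fake_img=[1.0, 1.0, 0.0, 0.0])
results, losses = two_stage_loss(s1, s2, gt, mask, masked_img,
                                 ('loss_l1_hole',), ('loss_l1_hole',), l1_stub)
print(losses)  # {'stage1_loss_l1_hole': 0.25, 'stage2_loss_l1_hole': 0.0}
```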

calculate_loss_with_type(loss_type, fake_res, fake_img, gt, mask, prefix='stage1_', fake_local=None)[source]

Calculate multiple types of losses.

Parameters
  • loss_type (str) – Type of the loss.

  • fake_res (torch.Tensor) – Direct results from model.

  • fake_img (torch.Tensor) – Composited results from model.

  • gt (torch.Tensor) – Ground-truth tensor.

  • mask (torch.Tensor) – Mask tensor.

  • prefix (str, optional) – Prefix for the loss name. Defaults to 'stage1_'.

  • fake_local (torch.Tensor, optional) – Local results from model. Defaults to None.

Returns

Contains each loss value keyed by its name.

Return type

dict
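The dispatch on loss_type can be sketched for the two L1 variants listed among the constructor parameters (hole and valid region). This is a hypothetical sketch, not mmedit code: lists of floats replace tensors, mask == 1 marks the hole, each L1 term is averaged over the pixels of its region (an assumed normalization), the L1 is taken on the direct output fake_res, and other loss types (adversarial, perceptual, TV, which may use fake_img or fake_local) are left out and raise an error.

```python
"""Hypothetical sketch of the dispatch in calculate_loss_with_type.

Assumptions: lists of floats stand in for tensors, mask == 1 marks the hole,
only the two L1 variants are implemented, and each L1 loss is averaged over
the pixels of its region.
"""
from typing import Dict, List, Optional

Img = List[float]


def masked_l1(pred: Img, target: Img, weight: Img) -> float:
    # Weighted absolute error, normalized by the total weight of the region.
    total = sum(w * abs(p - t) for p, t, w in zip(pred, target, weight))
    denom = sum(weight)
    return total / denom if denom > 0 else 0.0


def calculate_loss_with_type(loss_type: str, fake_res: Img, fake_img: Img,
                             gt: Img, mask: Img, prefix: str = 'stage1_',
                             fake_local: Optional[Img] = None
                             ) -> Dict[str, float]:
    if loss_type in ('loss_l1_hole', 'loss_l1_valid'):
        # Hole loss weights pixels inside the mask; valid loss the complement.
        if loss_type == 'loss_l1_valid':
            weight = [1.0 - m for m in mask]
        else:
            weight = list(mask)
        # In this sketch, L1 is computed on the direct output fake_res.
        return {prefix + loss_type: masked_l1(fake_res, gt, weight)}
    # fake_img / fake_local would feed loss types not covered here.
    raise NotImplementedError(f'loss type {loss_type!r} is not covered by this sketch')


gt = [1.0, 1.0, 0.0, 0.0]
mask = [0.0, 1.0, 1.0, 0.0]
fake_res = [0.5, 0.5, 0.5, 0.0]
print(calculate_loss_with_type('loss_l1_hole', fake_res, fake_res, gt, mask))
# {'stage1_loss_l1_hole': 0.5}
print(calculate_loss_with_type('loss_l1_valid', fake_res, fake_res, gt, mask,
                               prefix='stage2_'))
# {'stage2_loss_l1_valid': 0.25}
```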

train_step(data: List[dict], optim_wrapper)[source]

Train step function.

In this function, the inpaintor will finish the train step following the pipeline:

  1. compute the fake result/image

  2. optimize the discriminator (if any)

  3. optimize generator

If self.train_cfg.disc_step > 1, the discriminator is optimized for disc_step iterations, each with different input data, and the generator is optimized only once after those disc_step discriminator iterations.

Parameters
  • data (List[dict]) – Batch of data as input.

  • optim_wrapper (dict[torch.optim.Optimizer]) – Dict with the optimizers for the generator and the discriminator (if any).

Returns

Dict containing the losses, information for the logger, the number of samples, and results for visualization.

Return type

dict
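The disc_step behaviour described for train_step can be sketched with a counter taken modulo disc_step: every call updates the discriminator on its batch, and every disc_step-th call also updates the generator. This is a hypothetical sketch, not mmedit code; the class DiscStepScheduler and the step labels are illustrative, and it assumes a discriminator is present.

```python
"""Hypothetical sketch of disc_step scheduling inside train_step.

Assumption: a counter taken modulo disc_step decides when the generator is
updated; the class and step names are illustrative, not mmedit API.
"""
from typing import List


class DiscStepScheduler:
    def __init__(self, disc_step: int) -> None:
        if disc_step < 1:
            raise ValueError('disc_step must be >= 1')
        self.disc_step = disc_step
        self.disc_step_count = 0

    def train_step(self) -> List[str]:
        # Steps 1 and 2 run on every call, each call with a new data batch.
        updates = ['get_fake', 'discriminator']
        self.disc_step_count = (self.disc_step_count + 1) % self.disc_step
        if self.disc_step_count == 0:
            # Step 3: after disc_step discriminator updates, update G once.
            updates.append('generator')
        return updates


sched = DiscStepScheduler(disc_step=3)
for _ in range(3):
    print(sched.train_step())
# ['get_fake', 'discriminator']
# ['get_fake', 'discriminator']
# ['get_fake', 'discriminator', 'generator']
```

With disc_step=1 the counter wraps to 0 on every call, so the generator is updated at every training step.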
