TwoStageInpaintor

class mmedit.models.base_models.TwoStageInpaintor(data_preprocessor: Union[dict, mmengine.config.config.Config], encdec: dict, disc=None, loss_gan=None, loss_gp=None, loss_disc_shift=None, loss_composed_percep=None, loss_out_percep=False, loss_l1_hole=None, loss_l1_valid=None, loss_tv=None, train_cfg=None, test_cfg=None, init_cfg: Optional[dict] = None, stage1_loss_type=('loss_l1_hole',), stage2_loss_type=('loss_l1_hole', 'loss_gan'), input_with_ones=True, disc_input_with_mask=False)[source]

Standard two-stage inpaintor with commonly used losses. A two-stage inpaintor contains two encoder-decoder style generators to inpaint masked regions. The following loss types are currently supported for each of the two stages:

['loss_gan', 'loss_l1_hole', 'loss_l1_valid', 'loss_composed_percep', 'loss_out_percep', 'loss_tv']

Both stage1_loss_type and stage2_loss_type must be chosen from these loss types.

Parameters
  • data_preprocessor (dict) – Config of data_preprocessor.

  • encdec (dict) – Config for encoder-decoder style generator.

  • disc (dict) – Config for discriminator.

  • loss_gan (dict) – Config for adversarial loss.

  • loss_gp (dict) – Config for gradient penalty loss.

  • loss_disc_shift (dict) – Config for discriminator shift loss.

  • loss_composed_percep (dict) – Config for perceptual and style loss with composed image as input.

  • loss_out_percep (dict) – Config for perceptual and style loss with direct output as input.

  • loss_l1_hole (dict) – Config for l1 loss in the hole.

  • loss_l1_valid (dict) – Config for l1 loss in the valid region.

  • loss_tv (dict) – Config for total variation loss.

  • train_cfg (dict) – Configs for training scheduler. It must contain disc_step, which indicates the number of discriminator updating steps in each training step.

  • test_cfg (dict) – Configs for testing scheduler.

  • init_cfg (dict, optional) – Initialization config dict.

  • stage1_loss_type (tuple[str]) – Contains the loss names used in the first stage model. Default: ('loss_l1_hole',).

  • stage2_loss_type (tuple[str]) – Contains the loss names used in the second stage model. Default: ('loss_l1_hole', 'loss_gan').

  • input_with_ones (bool) – Whether to concatenate an extra ones tensor in input. Default: True.

  • disc_input_with_mask (bool) – Whether to add mask as input in discriminator. Default: False.
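As a minimal sketch of how the stage-wise loss selection might be written in a config, the fragment below builds a plain dict and checks the stage loss names against the supported list above. The helper check_stage_loss_types is hypothetical and not part of mmedit. The encdec, disc, data_preprocessor and loss configs are left out on purpose because their contents are not described here.

```python
# Loss names listed in the class description above.
SUPPORTED_LOSS_TYPES = (
    'loss_gan', 'loss_l1_hole', 'loss_l1_valid',
    'loss_composed_percep', 'loss_out_percep', 'loss_tv',
)


def check_stage_loss_types(stage_loss_type):
    """Hypothetical helper: return the names that are not supported."""
    return [name for name in stage_loss_type
            if name not in SUPPORTED_LOSS_TYPES]


# Illustrative config fragment only; the chosen loss names are an example,
# and the required encdec / data_preprocessor entries are omitted.
model = dict(
    type='TwoStageInpaintor',
    stage1_loss_type=('loss_l1_hole', 'loss_l1_valid'),
    stage2_loss_type=('loss_l1_hole', 'loss_l1_valid', 'loss_gan'),
    input_with_ones=True,
    disc_input_with_mask=False,
    train_cfg=dict(disc_step=1),
)

unsupported = (check_stage_loss_types(model['stage1_loss_type'])
               + check_stage_loss_types(model['stage2_loss_type']))
```

Each loss name selected for a stage presumably needs its matching loss config (for example loss_l1_valid) to be passed to the constructor as well.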

calculate_loss_with_type(loss_type, fake_res, fake_img, gt, mask, prefix='stage1_')[source]

Calculate multiple types of losses.

Parameters
  • loss_type (str) – Type of the loss.

  • fake_res (torch.Tensor) – Direct results from model.

  • fake_img (torch.Tensor) – Composited results from model.

  • gt (torch.Tensor) – Ground-truth tensor.

  • mask (torch.Tensor) – Mask tensor.

  • prefix (str, optional) – Prefix for loss name. Defaults to 'stage1_'.

Returns

Dict containing the loss value keyed by its name.

Return type

dict
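The sketch below illustrates one plausible way the returned dict could be built for the two L1 loss types. It assumes that mask equals 1 inside the hole, that the key is the prefix followed by the loss type, and that the loss is normalized by the sum of the weights. These are assumptions for illustration, not the library's confirmed behaviour. The functions masked_l1 and l1_loss_with_type are hypothetical and operate on flat Python lists instead of tensors.

```python
def masked_l1(pred, target, weight):
    """Weighted mean absolute error; positions with weight 0 are ignored."""
    total = sum(abs(p - t) * w for p, t, w in zip(pred, target, weight))
    norm = sum(weight)
    return total / norm if norm else 0.0


def l1_loss_with_type(loss_type, fake_res, gt, mask, prefix='stage1_'):
    """Hypothetical sketch covering 'loss_l1_hole' and 'loss_l1_valid' only."""
    if 'valid' in loss_type:
        # Valid region: everything outside the hole (assumes mask == 1 in hole).
        weight = [1 - m for m in mask]
    else:
        weight = mask
    return {prefix + loss_type: masked_l1(fake_res, gt, weight)}


fake_res = [1.0, 0.0, 0.5, 0.5]
gt = [0.0, 0.0, 0.5, 1.0]
mask = [1, 0, 0, 1]

hole = l1_loss_with_type('loss_l1_hole', fake_res, gt, mask)
valid = l1_loss_with_type('loss_l1_valid', fake_res, gt, mask, prefix='stage2_')
```

Prefixing the key lets the losses from both stages be logged side by side without name clashes.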

forward_tensor(inputs, data_samples)[source]

Forward function in tensor mode.

Parameters
  • inputs (torch.Tensor) – Input tensor.

  • data_samples (List[dict]) – List of data sample dict.

Returns

Dict containing the output results.

Return type

dict

train_step(data: List[dict], optim_wrapper)[source]

Train step function.

In this function, the inpaintor finishes the train step with the following pipeline:

1. Get the fake res/image.
2. Optimize the discriminator (if present).
3. Optimize the generator.

If self.train_cfg.disc_step > 1, the train step contains multiple iterations that optimize the discriminator with different input data, and only one iteration that optimizes the generator after disc_step discriminator iterations.

Parameters
  • data (List[dict]) – Batch of data as input.

  • optim_wrapper (dict[torch.optim.Optimizer]) – Dict with optimizers for the generator and the discriminator (if present).

Returns

Dict with loss, information for logger, the number of samples and results for visualization.

Return type

dict
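The disc_step behaviour described for train_step can be sketched as a simple counter: the discriminator is updated on every call, and the generator only once the counter wraps around. This is a simplified illustration with a hypothetical function name, not the actual method body.

```python
def train_schedule(num_calls, disc_step):
    """Return which networks would be updated on each train_step call.

    Hypothetical sketch: 'disc' is updated on every call, 'gen' only after
    disc_step discriminator updates have accumulated.
    """
    schedule = []
    disc_step_count = 0
    for _ in range(num_calls):
        updated = ['disc']
        disc_step_count = (disc_step_count + 1) % disc_step
        if disc_step_count == 0:
            updated.append('gen')
        schedule.append(updated)
    return schedule


every_call = train_schedule(2, disc_step=1)
every_other = train_schedule(4, disc_step=2)
```

With disc_step=1 both networks are updated on every call. With disc_step=2 the generator is updated on every second call.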

two_stage_loss(stage1_data, stage2_data, gt, mask, masked_img)[source]

Calculate two-stage loss.

Parameters
  • stage1_data (dict) – Contains stage1 results.

  • stage2_data (dict) – Contains stage2 results.

  • gt (torch.Tensor) – Ground-truth image.

  • mask (torch.Tensor) – Mask image.

  • masked_img (torch.Tensor) – Composition of mask image and ground-truth image.

Returns

A dict containing the results computed within this function for visualization, and a dict containing the loss items computed in this function.

Return type

tuple(dict)
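To make the relation between gt, mask, masked_img and the composited image concrete, here is a minimal sketch on flat lists. It assumes the common inpainting convention that mask is 1 inside the hole and 0 in the valid region. That convention and both function names are assumptions not stated on this page.

```python
def make_masked_img(gt, mask):
    """Zero out the hole region of the ground truth (assumes mask == 1 in hole)."""
    return [g * (1 - m) for g, m in zip(gt, mask)]


def compose(fake_res, gt, mask):
    """Keep the generator output inside the hole and the ground truth elsewhere."""
    return [f * m + g * (1 - m) for f, g, m in zip(fake_res, gt, mask)]


gt = [0.2, 0.4, 0.6]
mask = [0, 1, 0]
fake_res = [0.9, 0.8, 0.7]

masked_img = make_masked_img(gt, mask)
fake_img = compose(fake_res, gt, mask)
```

Under this convention, fake_res plays the role of the "direct results" and fake_img the "composited results" that the loss functions receive.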
