
mmedit.models.base_models.two_stage

Module Contents

Classes

TwoStageInpaintor

Standard two-stage inpaintor with commonly used losses.

class mmedit.models.base_models.two_stage.TwoStageInpaintor(data_preprocessor: Union[dict, mmengine.config.Config], encdec: dict, disc: Optional[dict] = None, loss_gan: Optional[dict] = None, loss_gp: Optional[dict] = None, loss_disc_shift: Optional[dict] = None, loss_composed_percep: Optional[dict] = None, loss_out_percep: bool = False, loss_l1_hole: Optional[dict] = None, loss_l1_valid: Optional[dict] = None, loss_tv: Optional[dict] = None, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None, stage1_loss_type: Optional[Sequence[str]] = ('loss_l1_hole',), stage2_loss_type: Optional[Sequence[str]] = ('loss_l1_hole', 'loss_gan'), input_with_ones: bool = True, disc_input_with_mask: bool = False)[source]

Bases: mmedit.models.base_models.one_stage.OneStageInpaintor

Standard two-stage inpaintor with commonly used losses. A two-stage inpaintor contains two encoder-decoder style generators to inpaint masked regions. Currently, the following loss types are supported in each of the two stages:

['loss_gan', 'loss_l1_hole', 'loss_l1_valid', 'loss_composed_percep', 'loss_out_percep', 'loss_tv']

The stage1_loss_type and stage2_loss_type should be chosen from these loss types.
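The constraint on stage1_loss_type and stage2_loss_type can be checked before a model is built. The following is a minimal sketch; the helper `check_loss_types` and the constant `SUPPORTED_LOSS_TYPES` are hypothetical names introduced for illustration and are not part of mmedit.

```python
# Hypothetical helper, not part of mmedit: validates stage loss names
# against the supported list quoted in the class description.
SUPPORTED_LOSS_TYPES = (
    'loss_gan', 'loss_l1_hole', 'loss_l1_valid',
    'loss_composed_percep', 'loss_out_percep', 'loss_tv',
)


def check_loss_types(loss_types):
    """Return the names in ``loss_types`` that are not supported."""
    return [name for name in loss_types if name not in SUPPORTED_LOSS_TYPES]


# The documented defaults for the two stages.
stage1_loss_type = ('loss_l1_hole',)
stage2_loss_type = ('loss_l1_hole', 'loss_gan')
unknown = check_loss_types(stage1_loss_type + stage2_loss_type)
```

An empty `unknown` list means every requested loss name is one of the supported types.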

Parameters
  • data_preprocessor (dict) – Config of data_preprocessor.

  • encdec (dict) – Config for encoder-decoder style generator.

  • disc (dict) – Config for discriminator.

  • loss_gan (dict) – Config for adversarial loss.

  • loss_gp (dict) – Config for gradient penalty loss.

  • loss_disc_shift (dict) – Config for discriminator shift loss.

  • loss_composed_percep (dict) – Config for perceptual and style loss with composed image as input.

  • loss_out_percep (bool) – Whether to also compute the perceptual and style loss with the direct output as input. Default: False.

  • loss_l1_hole (dict) – Config for l1 loss in the hole.

  • loss_l1_valid (dict) – Config for l1 loss in the valid region.

  • loss_tv (dict) – Config for total variation loss.

  • train_cfg (dict) – Configs for training scheduler. It must contain disc_step, which indicates the number of discriminator updating steps in each training step.

  • test_cfg (dict) – Configs for testing scheduler.

  • init_cfg (dict, optional) – Initialization config dict.

  • stage1_loss_type (tuple[str]) – Contains the loss names used in the first stage model. Default: ('loss_l1_hole',).

  • stage2_loss_type (tuple[str]) – Contains the loss names used in the second stage model. Default: ('loss_l1_hole', 'loss_gan').

  • input_with_ones (bool) – Whether to concatenate an extra ones tensor in input. Default: True.

  • disc_input_with_mask (bool) – Whether to add mask as input in discriminator. Default: False.
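As an illustration of how these parameters fit together, the sketch below assembles a model config as a plain dict. Every nested `type` string prefixed with `Hypothetical` and every numeric value is a placeholder assumption rather than a name or setting taken from mmedit; only the top-level keys follow the parameter list above.

```python
# Illustrative config sketch. Every nested ``type`` string and every numeric
# value below is a placeholder assumption, not a value taken from mmedit.
model = dict(
    type='TwoStageInpaintor',
    data_preprocessor=dict(type='HypotheticalDataPreprocessor'),
    encdec=dict(type='HypotheticalTwoStageEncoderDecoder'),
    disc=dict(type='HypotheticalDiscriminator'),
    loss_gan=dict(type='HypotheticalGANLoss', loss_weight=0.001),
    loss_l1_hole=dict(type='HypotheticalL1Loss', loss_weight=1.0),
    loss_l1_valid=dict(type='HypotheticalL1Loss', loss_weight=1.0),
    # Each stage lists the names of the configured losses it should use.
    stage1_loss_type=('loss_l1_hole', 'loss_l1_valid'),
    stage2_loss_type=('loss_l1_hole', 'loss_l1_valid', 'loss_gan'),
    # train_cfg must contain disc_step.
    train_cfg=dict(disc_step=1),
    input_with_ones=True,
    disc_input_with_mask=False,
)
```

In this sketch, every name listed in a stage's loss tuple also has a matching loss config key, so no stage refers to a loss that was left at its default of None.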

forward_tensor(inputs: torch.Tensor, data_samples: mmedit.utils.SampleList) → Tuple[torch.Tensor, torch.Tensor][source]

Forward function in tensor mode.

Parameters
  • inputs (torch.Tensor) – Input tensor.

  • data_samples (List[dict]) – List of data sample dict.

Returns

Tuple of two tensors containing the output results.

Return type

Tuple[torch.Tensor, torch.Tensor]

two_stage_loss(stage1_data: dict, stage2_data: dict, gt: torch.Tensor, mask: torch.Tensor, masked_img: torch.Tensor) → Tuple[dict, dict][source]

Calculate two-stage loss.

Parameters
  • stage1_data (dict) – Contains stage1 results.

  • stage2_data (dict) – Contains stage2 results.

  • gt (torch.Tensor) – Ground-truth image.

  • mask (torch.Tensor) – Mask image.

  • masked_img (torch.Tensor) – Composition of mask image and ground-truth image.

Returns

A dict containing the results computed within this function for visualization, and a dict containing the loss items computed in this function.

Return type

tuple(dict)
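The per-stage loss selection can be pictured as two loops, one per stage, each walking its own list of loss names. This is a toy sketch only: `toy_loss` stands in for the real per-type computation, and the `stage1_`/`stage2_` key prefixes are an assumption based on the default `prefix` argument of calculate_loss_with_type.

```python
# Toy sketch: ``toy_loss`` stands in for the real per-type loss computation,
# and the 'stage1_'/'stage2_' key prefixes are assumed from the default
# ``prefix`` argument of calculate_loss_with_type.
def toy_loss(loss_type, prefix):
    """Return a one-item dict with a dummy value under a prefixed name."""
    return {prefix + loss_type: 0.0}


def two_stage_loss_sketch(stage1_loss_type, stage2_loss_type):
    loss = {}
    # Stage 1 only contributes the losses listed in stage1_loss_type.
    for loss_type in stage1_loss_type:
        loss.update(toy_loss(loss_type, prefix='stage1_'))
    # Stage 2 does the same with its own list and prefix.
    for loss_type in stage2_loss_type:
        loss.update(toy_loss(loss_type, prefix='stage2_'))
    return loss


loss = two_stage_loss_sketch(('loss_l1_hole',), ('loss_l1_hole', 'loss_gan'))
```

With the default loss lists, the resulting dict holds one entry for the first stage and two for the second, and the prefixes keep the two `loss_l1_hole` entries from colliding.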

calculate_loss_with_type(loss_type: str, fake_res: torch.Tensor, fake_img: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor, prefix: Optional[str] = 'stage1_') → dict[source]

Calculate multiple types of losses.

Parameters
  • loss_type (str) – Type of the loss.

  • fake_res (torch.Tensor) – Direct results from model.

  • fake_img (torch.Tensor) – Composited results from model.

  • gt (torch.Tensor) – Ground-truth tensor.

  • mask (torch.Tensor) – Mask tensor.

  • prefix (str, optional) – Prefix for loss name. Defaults to 'stage1_'.

Returns

Dict containing the loss value with its name.

Return type

dict
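To make the roles of fake_res, fake_img, mask and prefix concrete, here is a toy sketch on flat lists of pixel values instead of tensors. It rests on several assumptions that the text above does not confirm: mask == 1 marks the hole, the L1 terms are averaged over the masked pixels, and the returned key is `prefix + loss_type`. All function names are hypothetical.

```python
# Toy sketch on flat lists of pixel values. Assumptions: mask == 1 marks the
# hole, the L1 terms are averaged over the masked pixels, and the returned
# key is ``prefix + loss_type``. None of this is confirmed by the docs above.
def composite(fake_res, gt, mask):
    """Keep ground truth in the valid region, model output in the hole."""
    return [r * m + g * (1 - m) for r, g, m in zip(fake_res, gt, mask)]


def masked_l1(pred, gt, weight):
    """Mean absolute error over pixels whose weight is non-zero."""
    total = sum(abs(p - g) * w for p, g, w in zip(pred, gt, weight))
    return total / max(sum(weight), 1)


def loss_with_type_sketch(loss_type, fake_res, gt, mask, prefix='stage1_'):
    if loss_type == 'loss_l1_hole':
        value = masked_l1(fake_res, gt, mask)
    elif loss_type == 'loss_l1_valid':
        value = masked_l1(fake_res, gt, [1 - m for m in mask])
    else:
        raise NotImplementedError(loss_type)
    return {prefix + loss_type: value}


gt = [0.0, 1.0, 1.0, 0.0]
mask = [0, 1, 1, 0]
fake_res = [0.5, 0.5, 1.0, 0.0]
fake_img = composite(fake_res, gt, mask)
hole = loss_with_type_sketch('loss_l1_hole', fake_res, gt, mask)
valid = loss_with_type_sketch('loss_l1_valid', fake_res, gt, mask,
                              prefix='stage2_')
```

Under these assumptions the composited image differs from the ground truth only inside the hole, which is why a direct output (fake_res) and a composited output (fake_img) are passed separately.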

train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) → dict[source]

Train step function.

In this function, the inpaintor will finish the train step following the pipeline:

  1. Get the fake res/image.
  2. Optimize the discriminator (if present).
  3. Optimize the generator.

If self.train_cfg.disc_step > 1, the train step will contain multiple iterations for optimizing the discriminator with different input data, and only one iteration for optimizing the generator after disc_step discriminator iterations.

Parameters
  • data (List[dict]) – Batch of data as input.

  • optim_wrapper (dict[torch.optim.Optimizer]) – Dict with optimizers for generator and discriminator (if any).

Returns

Dict with loss, information for logger, the number of samples and results for visualization.

Return type

dict
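The interaction between disc_step and the generator update can be sketched as a simple counter. This toy class models only the schedule described above: the class name, the counter name and the log are hypothetical, and real optimizers and losses are replaced by log entries.

```python
# Toy sketch of the update schedule only. The class and counter names are
# hypothetical; real optimizers and losses are replaced by log entries.
class ScheduleSketch:
    def __init__(self, disc_step):
        self.disc_step = disc_step
        self.disc_step_count = 0
        self.log = []

    def train_step(self):
        # The discriminator is optimized on every call.
        self.log.append('disc')
        self.disc_step_count = (self.disc_step_count + 1) % self.disc_step
        # The generator is optimized once per disc_step discriminator updates.
        if self.disc_step_count == 0:
            self.log.append('gen')


sketch = ScheduleSketch(disc_step=2)
for _ in range(4):
    sketch.train_step()
```

With disc_step=2, the log shows two discriminator updates before each generator update; with disc_step=1, every call would update both.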
