
mmedit.models.editors.global_local

Package Contents

Classes

GLDecoder

Decoder used in Global&Local model.

GLDilationNeck

Dilation Backbone used in Global&Local model.

GLDiscs

Discriminators in Global&Local.

GLEncoder

Encoder used in Global&Local model.

GLEncoderDecoder

Encoder-Decoder used in Global&Local model.

GLInpaintor

Inpaintor for global&local method.

class mmedit.models.editors.global_local.GLDecoder(in_channels=256, norm_cfg=None, act_cfg=dict(type='ReLU'), out_act='clip')[source]

Bases: mmengine.model.BaseModule

Decoder used in Global&Local model.

This implementation follows: Globally and Locally Consistent Image Completion

Parameters
  • in_channels (int) – Channel number of input feature.

  • norm_cfg (dict) – Config dict to build norm layer.

  • act_cfg (dict) – Config dict for activation layer, “relu” by default.

  • out_act (str) – Output activation type, “clip” by default. Note that in our implementation, the output is clipped to the range [-1, 1].
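A minimal sketch of what out_act='clip' means for the output values, written in plain Python so it runs without PyTorch. The helper name clip_output and the flat-list input are illustrative assumptions; on a torch.Tensor the same effect would come from torch.clamp.

```python
def clip_output(values, low=-1.0, high=1.0):
    """Clamp every value into [low, high] (hypothetical helper mimicking out_act='clip')."""
    return [max(low, min(high, v)) for v in values]


# Values outside [-1, 1] are saturated; values inside are left unchanged.
print(clip_output([-1.7, -0.5, 0.0, 0.8, 2.3]))
```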

forward(x)

Forward Function.

Parameters

x (torch.Tensor) – Input tensor with shape of (n, c, h, w).

Returns

Output tensor with shape of (n, c, h’, w’).

Return type

torch.Tensor

class mmedit.models.editors.global_local.GLDilationNeck(in_channels=256, conv_type='conv', norm_cfg=None, act_cfg=dict(type='ReLU'), **kwargs)[source]

Bases: mmengine.model.BaseModule

Dilation Backbone used in Global&Local model.

This implementation follows: Globally and Locally Consistent Image Completion

Parameters
  • in_channels (int) – Channel number of input feature.

  • conv_type (str) – The type of conv module. In DeepFillv1 model, the conv_type should be ‘conv’. In DeepFillv2 model, the conv_type should be ‘gated_conv’.

  • norm_cfg (dict) – Config dict to build norm layer.

  • act_cfg (dict) – Config dict for activation layer, “relu” by default.

  • kwargs (keyword arguments) –
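The sketch below traces the spatial size and receptive field through four stacked 3x3 dilated convolutions. The dilation rates (2, 4, 8, 16), stride 1 and padding equal to the dilation are assumptions taken from the layer table of the original paper, not read from the mmedit source; both helper functions are hypothetical.

```python
def conv_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Standard output-size formula for a convolution along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def dilation_neck_trace(size, dilations=(2, 4, 8, 16), kernel=3):
    """Trace spatial size and receptive field through stacked dilated convs.

    Assumes stride 1 and padding == dilation, so the spatial size is preserved.
    """
    receptive_field = 1
    for d in dilations:
        size = conv_out_size(size, kernel, stride=1, padding=d, dilation=d)
        # With stride 1, each layer widens the receptive field by d * (k - 1).
        receptive_field += d * (kernel - 1)
    return size, receptive_field


print(dilation_neck_trace(64))  # (64, 61)
```

Setting padding equal to the dilation keeps the feature map size unchanged, while the receptive field grows to 61 pixels of the neck's input feature map without any downsampling.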

_conv_type
forward(x)

Forward Function.

Parameters

x (torch.Tensor) – Input tensor with shape of (n, c, h, w).

Returns

Output tensor with shape of (n, c, h’, w’).

Return type

torch.Tensor

class mmedit.models.editors.global_local.GLDiscs(global_disc_cfg, local_disc_cfg)[source]

Bases: mmengine.model.BaseModule

Discriminators in Global&Local.

This discriminator contains a local discriminator and a global discriminator as described in the original paper: Globally and Locally Consistent Image Completion

Parameters
  • global_disc_cfg (dict) – Config dict to build global discriminator.

  • local_disc_cfg (dict) – Config dict to build local discriminator.

forward(x)

Forward function.

Parameters

x (tuple[torch.Tensor]) – Contains the global image and the local image patch.

Returns

Contains the predictions of the discriminators on the global image and the local image patch.

Return type

tuple[torch.Tensor]
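GLDiscs.forward takes a (global image, local patch) tuple. The following plain-Python sketch shows one way such a pair could be built by cropping a fixed-size window around the hole. crop_local_patch, the (top, left, height, width) bbox layout and the clamping rule are hypothetical, and nested lists stand in for tensors. In the original paper the local patch is 128x128 around the completed region, while the global input is the full 256x256 image.

```python
def crop_local_patch(img, bbox, patch_size):
    """Crop a patch_size x patch_size window centred on a hole bbox (hypothetical).

    img: 2D list (rows of pixels); bbox: (top, left, height, width).
    The window is shifted, if needed, so that it stays inside the image.
    """
    img_h, img_w = len(img), len(img[0])
    top, left, h, w = bbox
    centre_y, centre_x = top + h // 2, left + w // 2
    # Clamp the window's top-left corner to the valid range.
    y0 = min(max(centre_y - patch_size // 2, 0), img_h - patch_size)
    x0 = min(max(centre_x - patch_size // 2, 0), img_w - patch_size)
    return [row[x0:x0 + patch_size] for row in img[y0:y0 + patch_size]]


image = [[r * 8 + c for c in range(8)] for r in range(8)]
local = crop_local_patch(image, bbox=(0, 5, 2, 2), patch_size=4)
# Mirrors the (global image, local patch) tuple described above, with lists instead of tensors.
disc_input = (image, local)
```

Here the hole sits near the top-right corner, so the window is pushed back inside the 8x8 image rather than centred exactly on the hole.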

init_weights()

Initialize the weights of the model.

class mmedit.models.editors.global_local.GLEncoder(norm_cfg=None, act_cfg=dict(type='ReLU'))[source]

Bases: mmengine.model.BaseModule

Encoder used in Global&Local model.

This implementation follows: Globally and Locally Consistent Image Completion

Parameters
  • norm_cfg (dict) – Config dict to build norm layer.

  • act_cfg (dict) – Config dict for activation layer, “relu” by default.

forward(x)

Forward Function.

Parameters

x (torch.Tensor) – Input tensor with shape of (n, c, h, w).

Returns

Output tensor with shape of (n, c, h’, w’).

Return type

torch.Tensor

class mmedit.models.editors.global_local.GLEncoderDecoder(encoder=dict(type='GLEncoder'), decoder=dict(type='GLDecoder'), dilation_neck=dict(type='GLDilationNeck'))[source]

Bases: mmengine.model.BaseModule

Encoder-Decoder used in Global&Local model.

This implementation follows: Globally and Locally Consistent Image Completion

The architecture of the encoder-decoder is: (conv2d x 6) -> (dilated conv2d x 4) -> (conv2d or deconv2d x 7)
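To make the layer counts above concrete, this sketch traces channels and spatial size through the three stages using the standard convolution and transposed-convolution size formulas. The per-layer kernels, strides, paddings and channel widths follow the layer table in the Global&Local paper and are an assumption about this implementation, as is the 4-channel input (masked image plus mask).

```python
def conv_out(size, k, s, p, d=1):
    """Output size of a convolution along one spatial axis."""
    return (size + 2 * p - d * (k - 1) - 1) // s + 1


def deconv_out(size, k, s, p):
    """Output size of a transposed convolution (no output_padding)."""
    return (size - 1) * s - 2 * p + k


# (kind, out_channels, kernel, stride, padding, dilation), following the
# paper's layer table; assumed, not read from the mmedit source.
ENCODER = [("conv", 64, 5, 1, 2, 1), ("conv", 128, 3, 2, 1, 1),
           ("conv", 128, 3, 1, 1, 1), ("conv", 256, 3, 2, 1, 1),
           ("conv", 256, 3, 1, 1, 1), ("conv", 256, 3, 1, 1, 1)]
NECK = [("conv", 256, 3, 1, d, d) for d in (2, 4, 8, 16)]
DECODER = [("conv", 256, 3, 1, 1, 1), ("conv", 256, 3, 1, 1, 1),
           ("deconv", 128, 4, 2, 1, 1), ("conv", 128, 3, 1, 1, 1),
           ("deconv", 64, 4, 2, 1, 1), ("conv", 32, 3, 1, 1, 1),
           ("conv", 3, 3, 1, 1, 1)]


def trace(channels, size, layers):
    """Return (channels, spatial size) after applying each layer in order."""
    for kind, out_c, k, s, p, d in layers:
        if kind == "conv":
            size = conv_out(size, k, s, p, d)
        else:
            size = deconv_out(size, k, s, p)
        channels = out_c
    return channels, size


c, h = trace(4, 256, ENCODER)  # 4 input channels assumed (masked image + mask)
print("encoder:", (c, h))      # encoder: (256, 64)
c, h = trace(c, h, NECK)
print("neck:", (c, h))         # neck: (256, 64)
c, h = trace(c, h, DECODER)
print("decoder:", (c, h))      # decoder: (3, 256)
```

Under these assumptions the encoder downsamples by 4, the dilation neck preserves the size, and the two stride-2 transposed convolutions in the decoder restore the input resolution.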

Parameters
  • encoder (dict) – Config dict to build encoder.

  • decoder (dict) – Config dict to build decoder.

  • dilation_neck (dict) – Config dict to build dilation neck.

forward(x)

Forward Function.

Parameters

x (torch.Tensor) – Input tensor with shape of (n, c, h, w).

Returns

Output tensor with shape of (n, c, h’, w’).

Return type

torch.Tensor

class mmedit.models.editors.global_local.GLInpaintor(data_preprocessor: dict, encdec: dict, disc=None, loss_gan=None, loss_gp=None, loss_disc_shift=None, loss_composed_percep=None, loss_out_percep=False, loss_l1_hole=None, loss_l1_valid=None, loss_tv=None, train_cfg=None, test_cfg=None, init_cfg: Optional[dict] = None)[source]

Bases: mmedit.models.base_models.OneStageInpaintor

Inpaintor for global&local method.

This inpaintor is implemented according to the paper: Globally and Locally Consistent Image Completion

Notably, this inpaintor serves as an example of a custom training schedule built on OneStageInpaintor.

The training pipeline of global&local is as follows:

if cur_iter < iter_tc:
    update generator with only l1 loss
else:
    update discriminator
    if cur_iter > iter_td:
        update generator with l1 loss and adversarial loss

A new attribute, cur_iter, records the current iteration number. The train_cfg contains the settings of the training schedule:

train_cfg = dict(
    start_iter=0,
    disc_step=1,
    iter_tc=90000,
    iter_td=100000
)

iter_tc and iter_td correspond to the notation \(T_C\) and \(T_D\) in the original paper.
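The branching in the pseudocode above can be checked with a small plain-Python sketch. gl_schedule is a hypothetical helper that only reports which updates an iteration would perform; it uses the train_cfg values shown above and the exact comparisons from the pseudocode, which may differ from the boundary handling in mmedit's own train_step.

```python
def gl_schedule(cur_iter, train_cfg):
    """Return the updates one iteration performs, following the pseudocode above."""
    updates = []
    if cur_iter < train_cfg["iter_tc"]:
        # Phase 1: generator trained with the L1 loss only.
        updates.append("generator: l1")
    else:
        # From iter_tc on, the discriminator is always updated.
        updates.append("discriminator")
        if cur_iter > train_cfg["iter_td"]:
            # Final phase: generator also receives the adversarial loss.
            updates.append("generator: l1 + adversarial")
    return updates


train_cfg = dict(start_iter=0, disc_step=1, iter_tc=90000, iter_td=100000)
for it in (0, 95000, 100001):
    print(it, gl_schedule(it, train_cfg))
```

For iter_tc <= cur_iter <= iter_td only the discriminator is updated, which corresponds to the middle, discriminator-only phase of the paper's schedule.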

Parameters
  • encdec (dict) – Config for encoder-decoder style generator.

  • disc (dict) – Config for discriminator.

  • loss_gan (dict) – Config for adversarial loss.

  • loss_gp (dict) – Config for gradient penalty loss.

  • loss_disc_shift (dict) – Config for discriminator shift loss.

  • loss_composed_percep (dict) – Config for perceptual and style loss with the composed image as input.

  • loss_out_percep (dict) – Config for perceptual and style loss with the direct output as input.

  • loss_l1_hole (dict) – Config for L1 loss in the hole.

  • loss_l1_valid (dict) – Config for L1 loss in the valid region.

  • loss_tv (dict) – Config for total variation loss.

  • train_cfg (dict) – Configs for the training scheduler. It must contain disc_step, which indicates the number of discriminator update steps in each training step.

  • test_cfg (dict) – Configs for testing scheduler.

  • init_cfg (dict, optional) – Initialization config dict. Default: None.
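As an illustration of how these arguments nest, here is a partial model config written as Python dicts in the usual MMEngine style. Only the type names documented on this page are certain; the data preprocessor type 'EditDataPreprocessor' and the loss type 'L1Loss' are assumptions, and disc, loss_gan and the other losses are left out, so this fragment alone does not cover the adversarial phases of training.

```python
# Partial GLInpaintor config (sketch). Entries marked "assumed" are not
# documented on this page.
model = dict(
    type='GLInpaintor',
    data_preprocessor=dict(type='EditDataPreprocessor'),  # assumed type name
    encdec=dict(
        type='GLEncoderDecoder',
        encoder=dict(type='GLEncoder'),
        decoder=dict(type='GLDecoder', out_act='clip'),
        dilation_neck=dict(type='GLDilationNeck', conv_type='conv'),
    ),
    loss_l1_hole=dict(type='L1Loss', loss_weight=1.0),   # assumed loss type
    loss_l1_valid=dict(type='L1Loss', loss_weight=1.0),  # assumed loss type
    train_cfg=dict(start_iter=0, disc_step=1, iter_tc=90000, iter_td=100000),
)
```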

generator_loss(fake_res, fake_img, fake_local, gt, mask, masked_img)

Forward function in generator training step.

This function mainly computes the generator loss items from the given (fake_res, fake_img). In general, fake_res is the direct output of the generator, and fake_img is the composition of the direct output and the ground-truth image.

Parameters
  • fake_res (torch.Tensor) – Direct output of the generator.

  • fake_img (torch.Tensor) – Composition of fake_res and ground-truth image.

  • fake_local (torch.Tensor) – Local image patch.

  • gt (torch.Tensor) – Ground-truth image.

  • mask (torch.Tensor) – Mask image.

  • masked_img (torch.Tensor) – Composition of mask image and ground-truth image.

Returns

A tuple containing two dictionaries. The first one is the result dict, which contains the results computed within this function for visualization. The second one is the loss dict, containing loss items computed in this function.

Return type

tuple[dict]
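A plain-Python sketch of the composition and the two masked L1 terms described above. It assumes mask is 1 inside the hole and 0 in the valid region, and normalizes each L1 term by the number of elements it covers; both are assumptions, and compose and masked_l1 are hypothetical helpers operating on flat lists instead of tensors.

```python
def compose(fake_res, gt, mask):
    """fake_img: generator output inside the hole (mask == 1), ground truth elsewhere."""
    return [f * m + g * (1 - m) for f, g, m in zip(fake_res, gt, mask)]


def masked_l1(pred, target, weight):
    """Mean absolute error over the elements where weight is 1."""
    total = sum(abs(p - t) * w for p, t, w in zip(pred, target, weight))
    return total / max(sum(weight), 1e-8)


gt = [0.0, 0.5, 1.0, -0.5]
fake_res = [0.25, 0.5, 0.5, 0.0]
mask = [1, 0, 1, 0]
valid = [1 - m for m in mask]

fake_img = compose(fake_res, gt, mask)
loss_l1_hole = masked_l1(fake_res, gt, mask)    # errors inside the hole
loss_l1_valid = masked_l1(fake_res, gt, valid)  # errors in the valid region
```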

train_step(data: List[dict], optim_wrapper)

Train step function.

In this function, the inpaintor will finish the train step following the pipeline:

  1. get the fake result and fake image

  2. optimize discriminator (if in current schedule)

  3. optimize generator (if in current schedule)

If self.train_cfg.disc_step > 1, the discriminator is optimized for disc_step iterations, each with different input data, and only then is the generator optimized for a single iteration.
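One reading of the disc_step behaviour described above, as a hypothetical plain-Python sketch: every train_step call updates the discriminator on its own batch, and the generator is updated once every disc_step calls. It ignores the iter_tc/iter_td phases, and plan_updates is an illustrative helper rather than part of the mmedit API.

```python
def plan_updates(num_steps, disc_step):
    """List which models are updated in each train_step call (hypothetical sketch)."""
    plan = []
    disc_count = 0
    for _ in range(num_steps):
        actions = ["disc"]  # discriminator update on this call's batch
        disc_count = (disc_count + 1) % disc_step
        if disc_count == 0:
            # A full cycle of disc_step discriminator updates has finished.
            actions.append("gen")
        plan.append(actions)
    return plan


print(plan_updates(4, disc_step=2))
# [['disc'], ['disc', 'gen'], ['disc'], ['disc', 'gen']]
```

With disc_step=1 every call updates both models, which is the default shown in train_cfg.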

Parameters
  • data (List[dict]) – Batch of data as input.

  • optim_wrapper (dict[torch.optim.Optimizer]) – Dict with optimizers for the generator and the discriminator (if any).

Returns

Dict with loss, information for logger, the number of samples and results for visualization.

Return type

dict
