mmedit.models.editors.deepfillv1

Package Contents

Classes

ContextualAttentionModule

Contextual attention module.

ContextualAttentionNeck

Neck with contextual attention module.

DeepFillDecoder

Decoder used in DeepFill model.

DeepFillv1Discriminators

Discriminators used in DeepFillv1 model.

DeepFillEncoder

Encoder used in DeepFill model.

DeepFillRefiner

Refiner used in DeepFill model.

DeepFillv1Inpaintor

Inpaintor for deepfillv1 method.

class mmedit.models.editors.deepfillv1.ContextualAttentionModule(unfold_raw_kernel_size=4, unfold_raw_stride=2, unfold_raw_padding=1, unfold_corr_kernel_size=3, unfold_corr_stride=1, unfold_corr_dilation=1, unfold_corr_padding=1, scale=0.5, fuse_kernel_size=3, softmax_scale=10, return_attention_score=True) [source]

Bases: mmengine.model.BaseModule

Contextual attention module.

The details of this module can be found in: Generative Image Inpainting with Contextual Attention

Parameters
  • unfold_raw_kernel_size (int) – Kernel size used in unfolding raw feature. Default: 4.

  • unfold_raw_stride (int) – Stride used in unfolding raw feature. Default: 2.

  • unfold_raw_padding (int) – Padding used in unfolding raw feature. Default: 1.

  • unfold_corr_kernel_size (int) – Kernel size used in unfolding context for computing correlation maps. Default: 3.

  • unfold_corr_stride (int) – Stride used in unfolding context for computing correlation maps. Default: 1.

  • unfold_corr_dilation (int) – Dilation used in unfolding context for computing correlation maps. Default: 1.

  • unfold_corr_padding (int) – Padding used in unfolding context for computing correlation maps. Default: 1.

  • scale (float) – The rescale factor used to resize the input features. Default: 0.5.

  • fuse_kernel_size (int) – The kernel size used in fusion module. Default: 3.

  • softmax_scale (float) – The scale factor for softmax function. Default: 10.

  • return_attention_score (bool) – If True, the attention score will be returned. Default: True.

forward(x, context, mask=None)

Forward Function.

Parameters
  • x (torch.Tensor) – Tensor with shape (n, c, h, w).

  • context (torch.Tensor) – Tensor with shape (n, c, h, w).

  • mask (torch.Tensor) – Tensor with shape (n, 1, h, w). Default: None.

Returns

Features after contextual attention.

Return type

tuple(torch.Tensor)

patch_correlation(x, kernel)

Calculate patch correlation.

Parameters
  • x (torch.Tensor) – Input tensor.

  • kernel (torch.Tensor) – Kernel tensor.

Returns

Tensor with shape of (n, l, h, w).

Return type

torch.Tensor

patch_copy_deconv(attention_score, context_filter)

Copy patches using deconv.

Parameters
  • attention_score (torch.Tensor) – Tensor with shape of (n, l, h, w).

  • context_filter (torch.Tensor) – Filter kernel.

Returns

Tensor with shape of (n, c, h, w).

Return type

torch.Tensor

fuse_correlation_map(correlation_map, h_unfold, w_unfold)

Fuse correlation map.

This operation fuses the correlation map in order to enlarge consistent correlation regions.

The mechanism is simple: a standard identity (‘eye’) matrix is applied as a convolution filter on the correlation map in the horizontal and vertical directions.

The shape of the input correlation map is (n, h_unfold*w_unfold, h, w). When fusing, the convolutional filter is applied to the reshaped feature map with shape (n, 1, h_unfold*w_unfold, h*w).

A simple specification for horizontal direction is shown below:

        (h, 0)  (h, 1)  (h, 2)  (h, 3)  ...
(h, 0)
(h, 1)            1
(h, 2)                    1
(h, 3)                            1
...

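As a rough illustration of the eye-matrix filtering described above, the following pure-Python sketch convolves a small 2-D map with an identity kernel, using zero padding and stride 1. The function name `fuse_with_eye` and the toy input are hypothetical; the module itself works on torch tensors and applies the filter in both the horizontal and vertical directions.

```python
def fuse_with_eye(matrix, kernel_size=3):
    """Convolve a 2-D list with an identity kernel (zero padding, stride 1).

    Hypothetical helper for illustration only; assumes an odd kernel_size.
    """
    rows, cols = len(matrix), len(matrix[0])
    half = kernel_size // 2
    fused = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            # Only the diagonal taps of an identity kernel are non-zero,
            # so each output sums the neighbours along the main diagonal.
            for d in range(-half, half + 1):
                r, c = i + d, j + d
                if 0 <= r < rows and 0 <= c < cols:
                    fused[i][j] += matrix[r][c]
    return fused


toy_map = [[1, 2, 3],
           [4, 5, 6],
           [7, 8, 9]]
fused_map = fuse_with_eye(toy_map)
```

Each fused entry accumulates its diagonal neighbours, which is why runs of high correlation along the diagonal reinforce each other.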
calculate_unfold_hw(input_size, kernel_size=3, stride=1, dilation=1, padding=0)

Calculate (h, w) after unfolding.

The official PyTorch implementation of unfold flattens the spatial dimensions (h, w) into a single dimension L. This function therefore recovers (h, w) according to the equation in: https://pytorch.org/docs/stable/nn.html#torch.nn.Unfold
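The output-size equation from the torch.nn.Unfold documentation can be sketched in plain Python as follows. The helper name `unfold_output_hw` is hypothetical, and the sketch assumes `input_size` is an (h, w) pair with the same integer kernel size, stride, dilation and padding in both directions.

```python
def unfold_output_hw(input_size, kernel_size=3, stride=1, dilation=1, padding=0):
    """Return the (h, w) grid of sliding-window positions after unfolding."""
    def out_len(size):
        # floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1
        return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

    h_in, w_in = input_size
    return out_len(h_in), out_len(w_in)


# Number of patches L is the product of the two returned values.
h_unfold, w_unfold = unfold_output_hw((8, 8), kernel_size=4, stride=2, padding=1)
```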

calculate_overlap_factor(attention_score)

Calculate the overlap factor after applying deconv.

Parameters

attention_score (torch.Tensor) – The attention score with shape of (n, c, h, w).

Returns

The overlap factor.

Return type

torch.Tensor
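To give an intuition for why an overlap factor is needed, the sketch below counts, in one dimension, how many deconvolved patches land on each output position. It assumes the overlap factor is such a per-position coverage count, which the description above does not spell out. The function name `overlap_counts_1d` is hypothetical, and the defaults mirror the `unfold_raw_*` defaults listed for the module.

```python
def overlap_counts_1d(length, kernel_size=4, stride=2, padding=1):
    """Count how many transposed-convolution patches cover each output index."""
    # Standard transposed-convolution output length (no output padding).
    out_len = (length - 1) * stride - 2 * padding + kernel_size
    counts = [0] * out_len
    for i in range(length):
        start = i * stride - padding
        for k in range(kernel_size):
            pos = start + k
            # Positions that fall into the cropped padding are discarded.
            if 0 <= pos < out_len:
                counts[pos] += 1
    return counts


counts = overlap_counts_1d(4)
```

Dividing the deconvolved result by such counts would average the overlapping contributions instead of summing them.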

mask_correlation_map(correlation_map, mask)

Add mask weight for correlation map.

Adds negative infinity to the masked regions so that the softmax function outputs zero in those regions.

Parameters
  • correlation_map (torch.Tensor) – Correlation map with shape of (n, h_unfold*w_unfold, h_map, w_map).

  • mask (torch.Tensor) – Mask tensor with shape of (n, c, h, w). ‘1’ in the mask indicates masked region while ‘0’ indicates valid region.

Returns

Updated correlation map with mask.

Return type

torch.Tensor
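The effect of the negative-infinity mask can be seen in a small pure-Python sketch over a flat list of scores. The function name `masked_scaled_softmax` is hypothetical, and applying `softmax_scale` after masking is an assumption about the order of operations; the sketch also assumes at least one position is valid.

```python
import math


def masked_scaled_softmax(scores, mask, softmax_scale=10.0):
    """Softmax over scores where mask == 1 marks masked (excluded) entries."""
    # Adding -inf to masked entries makes their exponentials exactly zero.
    masked = [s + (float('-inf') if m == 1 else 0.0) for s, m in zip(scores, mask)]
    scaled = [v * softmax_scale for v in masked]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(v - peak) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


probs = masked_scaled_softmax([0.1, 0.3, 0.2], [0, 1, 0])
```

The masked entry receives zero attention, and the remaining probability mass is shared among the valid entries only.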

im2col(img, kernel_size, stride=1, padding=0, dilation=1, normalize=False, return_cols=False)

Reshape image-style feature to columns.

This function unfolds feature maps into columns. The details can be found in: https://pytorch.org/docs/1.1.0/nn.html?highlight=unfold#torch.nn.Unfold

Parameters
  • img (torch.Tensor) – Features to be unfolded, with shape (n, c, h, w).

  • kernel_size (int) – Kernel size. Only square kernels with the same height and width are supported.

  • stride (int) – Stride number in unfolding. Default: 1.

  • padding (int) – Padding number in unfolding. Default: 0.

  • dilation (int) – Dilation number in unfolding. Default: 1.

  • normalize (bool) – If True, the unfolded feature will be normalized. Default: False.

  • return_cols (bool) – The official PyTorch implementation of unfolding returns features with shape (n, c*kernel_size*kernel_size, L). If True, the features are reshaped to (n, L, c, kernel_size, kernel_size). Otherwise, the results keep the shape of the official implementation. Default: False.

Returns

Unfolded columns. If return_cols is True, the shape of the output tensor is (n, L, c, kernel_size, kernel_size). Otherwise, the shape is (n, c*kernel_size*kernel_size, L).

Return type

torch.Tensor
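For readers unfamiliar with unfolding, here is a minimal pure-Python sketch of the idea on a single-channel 2-D list. The name `im2col_2d` is hypothetical, dilation is fixed at 1, and the result is a list of L patches of size kernel_size x kernel_size, loosely mirroring the `return_cols=True` layout without the batch and channel dimensions.

```python
def im2col_2d(img, kernel_size, stride=1, padding=0):
    """Unfold a single-channel 2-D list into a list of square patches."""
    h, w = len(img), len(img[0])
    # Zero-pad the image on all four sides.
    padded = [[0] * (w + 2 * padding) for _ in range(h + 2 * padding)]
    for r in range(h):
        for c in range(w):
            padded[r + padding][c + padding] = img[r][c]

    out_h = (h + 2 * padding - kernel_size) // stride + 1
    out_w = (w + 2 * padding - kernel_size) // stride + 1
    patches = []
    # Patches are collected in row-major order of their top-left corner.
    for i in range(out_h):
        for j in range(out_w):
            top, left = i * stride, j * stride
            rows = padded[top:top + kernel_size]
            patches.append([row[left:left + kernel_size] for row in rows])
    return patches


image = [[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]
patches = im2col_2d(image, kernel_size=2)
padded_patches = im2col_2d(image, kernel_size=3, padding=1)
```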

class mmedit.models.editors.deepfillv1.ContextualAttentionNeck(in_channels, conv_type='conv', conv_cfg=None, norm_cfg=None, act_cfg=dict(type='ELU'), contextual_attention_args=dict(softmax_scale=10.0), **kwargs) [source]

Bases: mmengine.model.BaseModule

Neck with contextual attention module.

Parameters
  • in_channels (int) – The number of input channels.

  • conv_type (str) – The type of conv module. In DeepFillv1 model, the conv_type should be ‘conv’. In DeepFillv2 model, the conv_type should be ‘gated_conv’.

  • conv_cfg (dict | None) – Config of conv module. Default: None.

  • norm_cfg (dict | None) – Config of norm module. Default: None.

  • act_cfg (dict | None) – Config of activation layer. Default: dict(type='ELU').

  • contextual_attention_args (dict) – Config of contextual attention module. Default: dict(softmax_scale=10.).

  • kwargs (keyword arguments) –

_conv_type

forward(x, mask)

Forward Function.

Parameters
  • x (torch.Tensor) – Input tensor with shape of (n, c, h, w).

  • mask (torch.Tensor) – Input tensor with shape of (n, 1, h, w).

Returns

Output tensor with shape of (n, c, h’, w’).

Return type

torch.Tensor

class mmedit.models.editors.deepfillv1.DeepFillDecoder(in_channels, conv_type='conv', norm_cfg=None, act_cfg=dict(type='ELU'), out_act_cfg=dict(type='clip', min=-1.0, max=1.0), channel_factor=1.0, **kwargs) [source]

Bases: mmengine.model.BaseModule

Decoder used in DeepFill model.

This implementation follows: Generative Image Inpainting with Contextual Attention

Parameters
  • in_channels (int) – The number of input channels.

  • conv_type (str) – The type of conv module. In DeepFillv1 model, the conv_type should be ‘conv’. In DeepFillv2 model, the conv_type should be ‘gated_conv’.

  • norm_cfg (dict) – Config dict to build norm layer. Default: None.

  • act_cfg (dict) – Config dict for activation layer. Default: dict(type='ELU').

  • out_act_cfg (dict) – Config dict for the output activation layer. The commonly used clamp or clip operation is provided. Default: dict(type='clip', min=-1.0, max=1.0).

  • channel_factor (float) – The scale factor for channel size. Default: 1.

  • kwargs (keyword arguments) –

_conv_type

forward(input_dict)

Forward Function.

Parameters

input_dict (dict | torch.Tensor) – Input dict with middle features or torch.Tensor.

Returns

Output tensor with shape of (n, c, h, w).

Return type

torch.Tensor

class mmedit.models.editors.deepfillv1.DeepFillv1Discriminators(global_disc_cfg, local_disc_cfg) [source]

Bases: mmengine.model.BaseModule

Discriminators used in DeepFillv1 model.

In the DeepFillv1 model, the discriminators are independent, without the concatenation used in the Global&Local model. Thus, we call this model DeepFillv1Discriminators. It consists of a global discriminator and a local discriminator, which take the global and local input respectively.

The details can be found in: Generative Image Inpainting with Contextual Attention.

Parameters
  • global_disc_cfg (dict) – Config dict for global discriminator.

  • local_disc_cfg (dict) – Config dict for local discriminator.

forward(x)

Forward function.

Parameters

x (tuple[torch.Tensor]) – Contains global image and the local image patch.

Returns

Contains the predictions from the discriminators for the global image and the local image patch.

Return type

tuple[torch.Tensor]

init_weights()

Init weights for models.

class mmedit.models.editors.deepfillv1.DeepFillEncoder(in_channels=5, conv_type='conv', norm_cfg=None, act_cfg=dict(type='ELU'), encoder_type='stage1', channel_factor=1.0, **kwargs) [source]

Bases: mmengine.model.BaseModule

Encoder used in DeepFill model.

This implementation follows: Generative Image Inpainting with Contextual Attention

Parameters
  • in_channels (int) – The number of input channels. Default: 5.

  • conv_type (str) – The type of conv module. In DeepFillv1 model, the conv_type should be ‘conv’. In DeepFillv2 model, the conv_type should be ‘gated_conv’.

  • norm_cfg (dict) – Config dict to build norm layer. Default: None.

  • act_cfg (dict) – Config dict for activation layer. Default: dict(type='ELU').

  • encoder_type (str) – Type of the encoder. Should be one of [‘stage1’, ‘stage2_conv’, ‘stage2_attention’]. Default: ‘stage1’.

  • channel_factor (float) – The scale factor for channel size. Default: 1.

  • kwargs (keyword arguments) –

_conv_type

forward(x)

Forward Function.

Parameters

x (torch.Tensor) – Input tensor with shape of (n, c, h, w).

Returns

Output tensor with shape of (n, c, h’, w’).

Return type

torch.Tensor

class mmedit.models.editors.deepfillv1.DeepFillRefiner(encoder_attention=dict(type='DeepFillEncoder', encoder_type='stage2_attention'), encoder_conv=dict(type='DeepFillEncoder', encoder_type='stage2_conv'), dilation_neck=dict(type='GLDilationNeck', in_channels=128, act_cfg=dict(type='ELU')), contextual_attention=dict(type='ContextualAttentionNeck', in_channels=128), decoder=dict(type='DeepFillDecoder', in_channels=256)) [source]

Bases: mmengine.model.BaseModule

Refiner used in DeepFill model.

This implementation follows: Generative Image Inpainting with Contextual Attention.

Parameters
  • encoder_attention (dict) – Config dict for encoder used in branch with contextual attention module.

  • encoder_conv (dict) – Config dict for encoder used in branch with just convolutional operation.

  • dilation_neck (dict) – Config dict for dilation neck in branch with just convolutional operation.

  • contextual_attention (dict) – Config dict for contextual attention neck.

  • decoder (dict) – Config dict for decoder used to fuse and decode features.

forward(x, mask)

Forward Function.

Parameters
  • x (torch.Tensor) – Input tensor with shape of (n, c, h, w).

  • mask (torch.Tensor) – Input tensor with shape of (n, 1, h, w).

Returns

Output tensor with shape of (n, c, h’, w’).

Return type

torch.Tensor
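For reference, the default sub-module configs from the signature above can be written out as a plain Python dict. The variable name `refiner_kwargs` is hypothetical and nothing here is passed to the library. The observation that the decoder's 256 input channels equal the sum of the two 128-channel branches follows from the numbers in the signature; reading it as channel-wise concatenation is an inference, not something this page states.

```python
# Default keyword arguments of DeepFillRefiner, copied from its signature.
refiner_kwargs = dict(
    # Branch with the contextual attention module.
    encoder_attention=dict(type='DeepFillEncoder', encoder_type='stage2_attention'),
    contextual_attention=dict(type='ContextualAttentionNeck', in_channels=128),
    # Branch with plain (dilated) convolutions.
    encoder_conv=dict(type='DeepFillEncoder', encoder_type='stage2_conv'),
    dilation_neck=dict(type='GLDilationNeck', in_channels=128,
                       act_cfg=dict(type='ELU')),
    # Decoder that fuses and decodes the features of both branches.
    decoder=dict(type='DeepFillDecoder', in_channels=256),
)

branch_channels = (refiner_kwargs['contextual_attention']['in_channels']
                   + refiner_kwargs['dilation_neck']['in_channels'])
```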

class mmedit.models.editors.deepfillv1.DeepFillv1Inpaintor(data_preprocessor: dict, encdec: dict, disc=None, loss_gan=None, loss_gp=None, loss_disc_shift=None, loss_composed_percep=None, loss_out_percep=False, loss_l1_hole=None, loss_l1_valid=None, loss_tv=None, stage1_loss_type=None, stage2_loss_type=None, train_cfg=None, test_cfg=None, init_cfg: Optional[dict] = None) [source]

Bases: mmedit.models.base_models.TwoStageInpaintor

Inpaintor for deepfillv1 method.

This inpaintor is implemented according to the paper: Generative image inpainting with contextual attention

Importantly, this inpaintor is an example for using custom training schedule based on TwoStageInpaintor.

The training pipeline of deepfillv1 is as follows:

if cur_iter < iter_tc:
    update generator with only l1 loss
else:
    update discriminator
    if cur_iter > iter_td:
        update generator with l1 loss and adversarial loss

The new attribute cur_iter records the current iteration number. The train_cfg contains the settings of the training schedule:

train_cfg = dict(
    start_iter=0,
    disc_step=1,
    iter_tc=90000,
    iter_td=100000
)

iter_tc and iter_td correspond to the notation \(T_C\) and \(T_D\) of the original paper.
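The schedule above can be expressed as a small runnable function. This is only a sketch of the documented pseudocode: the function name `scheduled_updates` and the string labels are hypothetical, and the defaults are the `iter_tc` and `iter_td` values from the example `train_cfg`.

```python
def scheduled_updates(cur_iter, iter_tc=90000, iter_td=100000):
    """Return which networks are updated at a given iteration."""
    if cur_iter < iter_tc:
        # Warm-up phase: generator only, reconstruction loss only.
        return ['generator (l1 loss)']
    updates = ['discriminator']
    if cur_iter > iter_td:
        # Joint phase: the generator also receives the adversarial loss.
        updates.append('generator (l1 + adversarial loss)')
    return updates


warmup = scheduled_updates(0)
disc_only = scheduled_updates(95000)
joint = scheduled_updates(100001)
```

Note that between `iter_tc` and `iter_td` (inclusive of `iter_td`, because the comparison is strict) only the discriminator is updated.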

Parameters
  • encdec (dict) – Config for encoder-decoder style generator.

  • disc (dict) – Config for discriminator.

  • loss_gan (dict) – Config for adversarial loss.

  • loss_gp (dict) – Config for gradient penalty loss.

  • loss_disc_shift (dict) – Config for discriminator shift loss.

  • loss_composed_percep (dict) – Config for perceptual and style loss with composed image as input.

  • loss_out_percep (dict) – Config for perceptual and style loss with direct output as input.

  • loss_l1_hole (dict) – Config for l1 loss in the hole.

  • loss_l1_valid (dict) – Config for l1 loss in the valid region.

  • loss_tv (dict) – Config for total variation loss.

  • train_cfg (dict) – Configs for training scheduler. It must contain disc_step, which indicates the number of discriminator update steps in each training step.

  • test_cfg (dict) – Configs for testing scheduler.

  • init_cfg (dict, optional) – Initialization config dict.

forward_train_d(data_batch, is_real, is_disc)

Forward function in discriminator training step.

This function overrides the default implementation, which uses only one discriminator. The DeepFillv1 model uses two separate discriminators for global and local consistency.

Parameters
  • data_batch (torch.Tensor) – Batch of real data or fake data.

  • is_real (bool) – If True, the gan loss will regard this batch as real data. Otherwise, the gan loss will regard this batch as fake data.

  • is_disc (bool) – If True, this function is called in discriminator training step. Otherwise, this function is called in generator training step. This will help us to compute different types of adversarial loss, like LSGAN.

Returns

Contains the loss items computed in this function.

Return type

dict

two_stage_loss(stage1_data, stage2_data, gt, mask, masked_img)

Calculate two-stage loss.

Parameters
  • stage1_data (dict) – Contains stage1 results.

  • stage2_data (dict) – Contains stage2 results.

  • gt (torch.Tensor) – Ground-truth image.

  • mask (torch.Tensor) – Mask image.

  • masked_img (torch.Tensor) – Composition of mask image and ground-truth image.

Returns

A dict containing the results computed within this function for visualization, and a dict containing the loss items computed in this function.

Return type

tuple(dict)

calculate_loss_with_type(loss_type, fake_res, fake_img, gt, mask, prefix='stage1_', fake_local=None)

Calculate multiple types of losses.

Parameters
  • loss_type (str) – Type of the loss.

  • fake_res (torch.Tensor) – Direct results from model.

  • fake_img (torch.Tensor) – Composited results from model.

  • gt (torch.Tensor) – Ground-truth tensor.

  • mask (torch.Tensor) – Mask tensor.

  • prefix (str, optional) – Prefix for loss name. Defaults to ‘stage1_’.

  • fake_local (torch.Tensor, optional) – Local results from model. Defaults to None.

Returns

Contains the loss values with their names.

Return type

dict

train_step(data: List[dict], optim_wrapper)

Train step function.

In this function, the inpaintor will finish the train step following the pipeline:

  1. get fake res/image

  2. optimize discriminator (if present)

  3. optimize generator

If self.train_cfg.disc_step > 1, the train step will contain multiple iterations for optimizing the discriminator with different input data, and only one iteration for optimizing the generator after disc_step discriminator iterations.

Parameters
  • data (List[dict]) – Batch of data as input.

  • optim_wrapper (dict[torch.optim.Optimizer]) – Dict with optimizers for generator and discriminator (if present).

Returns

Dict with loss, information for logger, the number of samples and results for visualization.

Return type

dict
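The interplay between discriminator and generator updates under `disc_step` can be sketched with a simple counter. This is a hypothetical model of the behaviour described above, not the library's code: it assumes every call performs one discriminator update on a fresh batch, and that the generator is updated once every `disc_step` calls. The names `plan_train_steps` and `disc_step_count` are illustrative.

```python
def plan_train_steps(num_steps, disc_step=1):
    """List which networks would be optimized in each of num_steps calls."""
    plan = []
    disc_step_count = 0
    for _ in range(num_steps):
        updates = ['discriminator']
        # Wrap the counter so the generator runs once per disc_step calls.
        disc_step_count = (disc_step_count + 1) % disc_step
        if disc_step_count == 0:
            updates.append('generator')
        plan.append(updates)
    return plan


plan = plan_train_steps(4, disc_step=2)
```

With `disc_step=1` the counter wraps on every call, so both networks are updated each time.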
