
mmedit.models.editors.aotgan

Package Contents

Classes

AOTDecoder

Decoder used in AOT-GAN model.

AOTEncoder

Encoder used in AOT-GAN model.

AOTEncoderDecoder

Encoder-Decoder used in AOT-GAN model.

AOTInpaintor

Inpaintor for AOT-GAN method.

AOTBlockNeck

Dilation backbone used in AOT-GAN model.

class mmedit.models.editors.aotgan.AOTDecoder(in_channels=256, mid_channels=128, out_channels=3, act_cfg=dict(type='ReLU'))[source]

Bases: mmengine.model.BaseModule

Decoder used in AOT-GAN model.

This implementation follows: Aggregated Contextual Transformations for High-Resolution Image Inpainting

Parameters
  • in_channels (int, optional) – Channel number of input feature. Default: 256.

  • mid_channels (int, optional) – Channel number of middle feature. Default: 128.

  • out_channels (int, optional) – Channel number of output feature. Default: 3.

  • act_cfg (dict, optional) – Config dict for the activation layer. Default: dict(type='ReLU').

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape of (n, c, h, w).

Returns

Output tensor with shape of (n, c, h', w').

Return type

Tensor
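To make the output shape (n, c, h', w') concrete, the sketch below traces one spatial size through a decoder that is assumed, not confirmed by this page, to upsample twice by a factor of 2, each time followed by a 3x3 convolution with stride 1 and padding 1, before a final size-preserving layer that maps to out_channels. The function names are hypothetical; only the convolution size formula is the standard one.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1,
               padding: int = 0, dilation: int = 1) -> int:
    """Output length of one spatial dimension (standard Conv2d formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def decoder_out_shape(n: int, h: int, w: int, out_channels: int = 3,
                      num_upsamples: int = 2) -> tuple:
    """Hypothetical shape trace: per stage, upsample x2 then 3x3 conv (s=1, p=1)."""
    for _ in range(num_upsamples):
        h, w = h * 2, w * 2                            # assumed x2 upsampling
        h = conv2d_out(h, kernel=3, stride=1, padding=1)  # size is preserved
        w = conv2d_out(w, kernel=3, stride=1, padding=1)
    # Assumed final 3x3 conv (s=1, p=1) keeps h, w and sets the channel count.
    return (n, out_channels, h, w)
```

Under these assumptions a 256-channel 64x64 feature map from the neck would come back as a 3-channel 256x256 image, that is, h' = 4h and w' = 4w.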

class mmedit.models.editors.aotgan.AOTEncoder(in_channels=4, mid_channels=64, out_channels=256, act_cfg=dict(type='ReLU'))[source]

Bases: mmengine.model.BaseModule

Encoder used in AOT-GAN model.

This implementation follows: Aggregated Contextual Transformations for High-Resolution Image Inpainting

Parameters
  • in_channels (int, optional) – Channel number of input feature. Default: 4.

  • mid_channels (int, optional) – Channel number of middle feature. Default: 64.

  • out_channels (int, optional) – Channel number of output feature. Default: 256.

  • act_cfg (dict, optional) – Config dict for the activation layer. Default: dict(type='ReLU').

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape of (n, c, h, w).

Returns

Output tensor with shape of (n, c, h', w').

Return type

Tensor
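The default in_channels=4 is consistent with feeding a 3-channel masked image concatenated with a 1-channel mask. Treat that, and the layer stack below (a 7x7 stride-1 convolution with 3 pixels of padding, then two 4x4 stride-2 convolutions with padding 1), as assumptions for illustration rather than a description of the mmedit code. Under them, the encoder reduces h and w by a factor of 4 and outputs out_channels feature channels. ASSUMED_ENCODER_LAYERS and encoder_out_shape are hypothetical names.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1,
               padding: int = 0, dilation: int = 1) -> int:
    """Output length of one spatial dimension (standard Conv2d formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Hypothetical layer stack: (out_channels, kernel, stride, padding).
ASSUMED_ENCODER_LAYERS = [
    (64, 7, 1, 3),    # mid_channels, resolution kept
    (128, 4, 2, 1),   # mid_channels * 2, resolution halved
    (256, 4, 2, 1),   # out_channels, resolution halved again
]


def encoder_out_shape(n: int, h: int, w: int, image_channels: int = 3,
                      mask_channels: int = 1,
                      layers=ASSUMED_ENCODER_LAYERS) -> tuple:
    """Trace (n, c, h, w) through the assumed encoder layers."""
    c = image_channels + mask_channels  # 3 + 1 = 4, matching in_channels=4
    for out_c, kernel, stride, padding in layers:
        h = conv2d_out(h, kernel, stride, padding)
        w = conv2d_out(w, kernel, stride, padding)
        c = out_c
    return (n, c, h, w)
```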

class mmedit.models.editors.aotgan.AOTEncoderDecoder(encoder=dict(type='AOTEncoder'), decoder=dict(type='AOTDecoder'), dilation_neck=dict(type='AOTBlockNeck'))[source]

Bases: mmedit.models.editors.global_local.GLEncoderDecoder

Encoder-Decoder used in AOT-GAN model.

This implementation follows: Aggregated Contextual Transformations for High-Resolution Image Inpainting. The architecture of the encoder-decoder is: (conv2d x 3) -> (dilated conv2d x 8) -> (conv2d or deconv2d x 3).

Parameters
  • encoder (dict) – Config dict to build encoder.

  • decoder (dict) – Config dict to build decoder.

  • dilation_neck (dict) – Config dict to build dilation neck.
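Written out with the defaults from the signatures on this page, a generator config could look like the sketch below; a dict of this shape is what could be passed as the encdec argument of AOTInpaintor. The check_channels helper is hypothetical and only illustrates that the encoder output width, the neck width, and the decoder input width must agree (256 by default).

```python
# Generator config spelled out with the defaults documented on this page.
encdec = dict(
    type='AOTEncoderDecoder',
    encoder=dict(type='AOTEncoder', in_channels=4, mid_channels=64,
                 out_channels=256, act_cfg=dict(type='ReLU')),
    dilation_neck=dict(type='AOTBlockNeck', in_channels=256,
                       dilation_rates=(1, 2, 4, 8), num_aotblock=8,
                       act_cfg=dict(type='ReLU')),
    decoder=dict(type='AOTDecoder', in_channels=256, mid_channels=128,
                 out_channels=3, act_cfg=dict(type='ReLU')),
)


def check_channels(cfg: dict) -> tuple:
    """Hypothetical helper: the three stages must share one feature width."""
    enc, neck, dec = cfg['encoder'], cfg['dilation_neck'], cfg['decoder']
    assert enc['out_channels'] == neck['in_channels'] == dec['in_channels']
    # Return the generator's input and output channel counts.
    return enc['in_channels'], dec['out_channels']
```

Since these values are the defaults, the shorter encoder=dict(type='AOTEncoder') form from the signature should be equivalent; spelling them out mainly helps when changing one width, because the other two must then change with it.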

class mmedit.models.editors.aotgan.AOTInpaintor(data_preprocessor: Union[dict, mmengine.config.Config], encdec: dict, disc: Optional[dict] = None, loss_gan: Optional[dict] = None, loss_gp: Optional[dict] = None, loss_disc_shift: Optional[dict] = None, loss_composed_percep: Optional[dict] = None, loss_out_percep: bool = False, loss_l1_hole: Optional[dict] = None, loss_l1_valid: Optional[dict] = None, loss_tv: Optional[dict] = None, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None)[source]

Bases: mmedit.models.base_models.OneStageInpaintor

Inpaintor for AOT-GAN method.

This inpaintor is implemented according to the paper: Aggregated Contextual Transformations for High-Resolution Image Inpainting

forward_train_d(data_batch, is_real, is_disc, mask)

Forward function in the discriminator training step.

In this function, we compute the prediction for each data batch (real or fake). Meanwhile, the standard GAN loss is computed together with several proposed losses for stable training.

Parameters
  • data_batch (torch.Tensor) – Batch of real data or fake data.

  • is_real (bool) – If True, the GAN loss will regard this batch as real data. Otherwise, the GAN loss will regard this batch as fake data.

  • is_disc (bool) – If True, this function is called in the discriminator training step. Otherwise, it is called in the generator training step. This distinction is needed to compute different types of adversarial loss, such as LSGAN.

  • mask (torch.Tensor) – Mask of data.

Returns

A dict containing the loss items computed in this function.

Return type

dict
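The sketch below illustrates, under stated assumptions, how is_real and is_disc can steer a least-squares (LSGAN) loss: is_real picks a target of 1.0 or 0.0, and is_disc is assumed here to decide whether a generator-side loss_weight is applied, a convention used by several GAN-loss implementations. The optional target_map is a hypothetical per-position target, hinting at one way a mask could shape the targets; it is not taken from mmedit. Scores are plain Python lists rather than tensors.

```python
from typing import Optional, Sequence


def lsgan_loss(pred: Sequence[float], is_real: bool, is_disc: bool,
               loss_weight: float = 1.0,
               target_map: Optional[Sequence[float]] = None) -> float:
    """Least-squares GAN loss averaged over discriminator scores (sketch)."""
    if target_map is not None:
        # Hypothetical per-position targets, e.g. derived from a mask.
        targets = list(target_map)
    else:
        # Real data is pushed toward 1.0, fake data toward 0.0.
        targets = [1.0 if is_real else 0.0] * len(pred)
    loss = sum((p - t) ** 2 for p, t in zip(pred, targets)) / len(pred)
    # Assumed convention: the weight only scales the generator-side loss.
    return loss if is_disc else loss * loss_weight
```

In a generator step, a fake batch is typically scored with is_real=True and is_disc=False, so the generator is rewarded for outputs that the discriminator rates as real.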

generator_loss(fake_res, fake_img, gt, mask, masked_img)

Forward function in the generator training step.

In this function, we mainly compute the loss items for the generator from the given (fake_res, fake_img). In general, fake_res is the direct output of the generator, and fake_img is the composition of that output and the ground-truth image.

Parameters
  • fake_res (torch.Tensor) – Direct output of the generator.

  • fake_img (torch.Tensor) – Composition of fake_res and ground-truth image.

  • gt (torch.Tensor) – Ground-truth image.

  • mask (torch.Tensor) – Mask image.

  • masked_img (torch.Tensor) – Composition of mask image and ground-truth image.

Returns

A dict containing the results computed within this function for visualization, and a dict containing the loss items computed in this function.

Return type

tuple(dict)
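The loss_l1_hole and loss_l1_valid options in the AOTInpaintor signature suggest separate L1 terms for the missing and the known region. A minimal sketch on flattened pixel lists, assuming mask == 1 marks hole pixels and normalising both terms by the total pixel count (other normalisations are equally plausible); masked_l1 is a hypothetical name.

```python
from typing import Sequence, Tuple


def masked_l1(fake_res: Sequence[float], gt: Sequence[float],
              mask: Sequence[float]) -> Tuple[float, float]:
    """Return (hole_l1, valid_l1), each a mean over all pixels.

    Assumes mask == 1 for missing (hole) pixels and 0 for valid pixels.
    """
    n = len(gt)
    diffs = [abs(f - g) for f, g in zip(fake_res, gt)]
    # Error restricted to the hole region.
    hole = sum(d * m for d, m in zip(diffs, mask)) / n
    # Error restricted to the known (valid) region.
    valid = sum(d * (1.0 - m) for d, m in zip(diffs, mask)) / n
    return hole, valid
```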

forward_tensor(inputs, data_samples)

Forward function in tensor mode.

Parameters
  • inputs (torch.Tensor) – Input tensor.

  • data_samples (List[dict]) – List of data sample dict.

Returns

Direct output of the generator and the composition of fake_res and the ground-truth image.

Return type

tuple
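The composition returned as the second element can be written per pixel as fake_img = fake_res * mask + gt * (1 - mask), assuming mask == 1 marks the hole: generated pixels are kept inside the hole and known pixels outside it. A plain-Python sketch on flattened pixels, with compose as a hypothetical name:

```python
from typing import List, Sequence


def compose(fake_res: Sequence[float], gt: Sequence[float],
            mask: Sequence[float]) -> List[float]:
    """Keep generated pixels in the hole and ground-truth pixels elsewhere.

    Assumes mask == 1 for hole pixels and 0 for valid pixels.
    """
    return [f * m + g * (1.0 - m) for f, g, m in zip(fake_res, gt, mask)]
```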

train_step(data: List[dict], optim_wrapper)

Train step function.

In this function, the inpaintor finishes the train step following this pipeline:

1. Get the fake result/image.
2. Compute the reconstruction losses for the generator.
3. Compute the adversarial loss for the discriminator.
4. Optimize the generator.
5. Optimize the discriminator.

Parameters
  • data (List[dict]) – Batch of data as input.

  • optim_wrapper (dict[torch.optim.Optimizer]) – Dict with optimizers for the generator and the discriminator (if any).

Returns

Dict with the loss, information for the logger, the number of samples, and results for visualization.

Return type

dict
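The five-step order can be outlined with plain callables standing in for the generator, the losses, and the optimizers. Everything in this sketch, including train_step_sketch and the returned key names, is hypothetical; it only mirrors the ordering listed above, not mmedit's actual train_step.

```python
from typing import Callable, Dict


def train_step_sketch(get_fake: Callable[[], float],
                      recon_loss: Callable[[float], float],
                      disc_loss: Callable[[float], float],
                      step_gen: Callable[[float], None],
                      step_disc: Callable[[float], None]) -> Dict[str, float]:
    """Hypothetical outline of the documented five-step order."""
    fake = get_fake()           # 1. get the fake result/image
    g_loss = recon_loss(fake)   # 2. reconstruction losses for the generator
    d_loss = disc_loss(fake)    # 3. adversarial loss for the discriminator
    step_gen(g_loss)            # 4. optimize the generator
    step_disc(d_loss)           # 5. optimize the discriminator
    # Hypothetical log keys, standing in for the dict the real method returns.
    return {'loss_g': g_loss, 'loss_d': d_loss}
```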

class mmedit.models.editors.aotgan.AOTBlockNeck(in_channels=256, dilation_rates=(1, 2, 4, 8), num_aotblock=8, act_cfg=dict(type='ReLU'), **kwargs)[source]

Bases: mmengine.model.BaseModule

Dilation backbone used in AOT-GAN model.

This implementation follows: Aggregated Contextual Transformations for High-Resolution Image Inpainting

Parameters
  • in_channels (int, optional) – Channel number of input feature. Default: 256.

  • dilation_rates (Tuple[int], optional) – The dilation rates used for AOT block. Default: (1, 2, 4, 8).

  • num_aotblock (int, optional) – Number of AOT blocks. Default: 8.

  • act_cfg (dict, optional) – Config dict for the activation layer. Default: dict(type='ReLU').

  • kwargs (keyword arguments) – Other keyword arguments.
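The AOT-GAN paper describes each AOT block as splitting the features across parallel dilated convolutions, concatenating the branches, and blending the result with the block input through a learned gate. Assuming an equal channel split and 3x3 kernels (not confirmed by this page), the defaults give four 64-channel branches whose effective kernel sizes follow d * (k - 1) + 1. The helpers below are hypothetical, and the direction of the gate (whether g weights the input or the transformed output) is likewise an assumption.

```python
from typing import List, Sequence, Tuple


def aot_branch_layout(in_channels: int = 256,
                      dilation_rates: Sequence[int] = (1, 2, 4, 8),
                      kernel: int = 3) -> List[Tuple[int, int, int]]:
    """(dilation, branch_channels, effective_kernel) for each assumed branch."""
    assert in_channels % len(dilation_rates) == 0, 'channels must split evenly'
    width = in_channels // len(dilation_rates)
    # A k x k kernel with dilation d spans d * (k - 1) + 1 pixels.
    return [(d, width, d * (kernel - 1) + 1) for d in dilation_rates]


def gated_fuse(x: float, y: float, g: float) -> float:
    """Blend the block input x and the aggregated output y with a gate g in [0, 1]."""
    return x * (1.0 - g) + y * g
```

Stacking num_aotblock=8 such blocks gives the "(dilated conv2d x 8)" middle stage of the encoder-decoder.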

forward(x)