
mmedit.models.editors.dic

Package Contents

Classes

DIC

DIC model for Face Super-Resolution.

DICNet

DIC network structure for face super-resolution.

FeedbackBlock

Feedback Block of DIC.

FeedbackBlockCustom

Custom feedback block, used as the first feedback block.

FeedbackBlockHeatmapAttention

Feedback block with HeatmapAttention.

FeedbackHourglass

Feedback Hourglass model for face landmark.

LightCNN

LightCNN discriminator with input size 128 x 128.

MaxFeature

Conv2d or Linear layer with max feature selector.

class mmedit.models.editors.dic.DIC(generator, pixel_loss, align_loss, discriminator=None, gan_loss=None, feature_loss=None, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)[source]

Bases: mmedit.models.editors.srgan.SRGAN

DIC model for Face Super-Resolution.

Paper: Deep Face Super-Resolution with Iterative Collaboration between Attentive Recovery and Landmark Estimation.

Parameters
  • generator (dict) – Config for the generator.

  • pixel_loss (dict) – Config for the pixel loss.

  • align_loss (dict) – Config for the align loss.

  • discriminator (dict) – Config for the discriminator. Default: None.

  • gan_loss (dict) – Config for the gan loss. Default: None.

  • feature_loss (dict) – Config for the feature loss. Default: None.

  • train_cfg (dict) – Config for training. Default: None.

  • test_cfg (dict) – Config for testing. Default: None.

  • init_cfg (dict, optional) – The weight initialized config for BaseModule. Default: None.

  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor. Default: None.
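A model config for DIC might look like the following minimal sketch. The generator arguments follow the DICNet signature documented on this page; the concrete values, and the choice of `L1Loss` and `MSELoss` for the pixel and align losses, are illustrative assumptions rather than a reference config. The optional GAN-related arguments are left at their `None` defaults.

```python
# Illustrative DIC config sketch; values and loss choices are assumptions.
model = dict(
    type='DIC',
    generator=dict(
        type='DICNet',
        in_channels=3,
        out_channels=3,
        mid_channels=48,      # required by DICNet (no default in its signature)
        num_steps=4,
        upscale_factor=8),
    pixel_loss=dict(type='L1Loss', loss_weight=1.0),   # assumed loss type/weight
    align_loss=dict(type='MSELoss', loss_weight=0.1),  # assumed loss type/weight
    # discriminator, gan_loss and feature_loss keep their None defaults here.
    train_cfg=dict(),
    test_cfg=dict(),
)

# The three DIC arguments without a default value must appear in the config.
REQUIRED_DIC_ARGS = ('generator', 'pixel_loss', 'align_loss')
missing = [name for name in REQUIRED_DIC_ARGS if name not in model]
print('missing required args:', missing)
```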

forward_tensor(inputs, data_samples=None, training=False)

Forward in tensor mode. Returns the result of a simple forward pass.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

  • training (bool) – Whether the model is in training mode. Default: False.

Returns

Results of forward inference and forward training.

Return type

(Tensor | Tuple[List[Tensor]])

if_run_g()

Calculates whether the generator step needs to run.

if_run_d()

Calculates whether the discriminator step needs to run.

g_step(batch_outputs, batch_gt_data)

G step of GAN: Calculate losses of generator.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

Returns

Dict of losses.

Return type

dict
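The loss computation in `g_step` can be sketched as below, assuming that `batch_outputs` is the `(sr_outputs, heatmap_outputs)` pair of per-step lists returned by `DICNet.forward` and that `batch_gt_data` is a `(gt, gt_heatmap)` pair. Plain Python lists stand in for tensors, L1 and MSE stand in for the configured `pixel_loss` and `align_loss`, and the loss key names and weights are hypothetical.

```python
from typing import Dict, List, Sequence


def l1(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean absolute error (stand-in for the configured pixel_loss)."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(target)


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error (stand-in for the configured align_loss)."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def g_step_losses(sr_outputs: List[Sequence[float]],
                  heatmap_outputs: List[Sequence[float]],
                  gt: Sequence[float],
                  gt_heatmap: Sequence[float],
                  pixel_weight: float = 1.0,
                  align_weight: float = 0.1) -> Dict[str, float]:
    """Hypothetical per-step generator losses: each iterative step is supervised."""
    losses = {}
    for step, (sr, heatmap) in enumerate(zip(sr_outputs, heatmap_outputs)):
        losses[f'loss_pixel_step{step}'] = pixel_weight * l1(sr, gt)
        losses[f'loss_align_step{step}'] = align_weight * mse(heatmap, gt_heatmap)
    return losses


losses = g_step_losses(
    sr_outputs=[[0.0, 1.0], [1.0, 1.0]],
    heatmap_outputs=[[0.0, 0.0], [0.0, 1.0]],
    gt=[1.0, 1.0],
    gt_heatmap=[0.0, 1.0])
print(losses, sum(losses.values()))
```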

d_step_with_optim(batch_outputs, batch_gt_data, optim_wrapper)

D step of GAN with optimization: calculate the discriminator losses and run the optimizer step.

Parameters
  • batch_outputs (Tuple[Tensor]) – Batch output of generator.

  • batch_gt_data (Tuple[Tensor]) – Batch GT data.

  • optim_wrapper (OptimWrapper) – Optim wrapper of discriminator.

Returns

Dict of parsed losses.

Return type

dict

static extract_gt_data(data_samples)

Extract GT data from data samples.

Parameters

data_samples (list) – List of EditDataSample.

Returns

Extracted GT data.

Return type

Tensor

class mmedit.models.editors.dic.DICNet(in_channels, out_channels, mid_channels, num_blocks=6, hg_mid_channels=256, hg_num_keypoints=68, num_steps=4, upscale_factor=8, detach_attention=False, prelu_init=0.2, num_heatmaps=5, num_fusion_blocks=7)[source]

Bases: mmengine.model.BaseModule

DIC network structure for face super-resolution.

Paper: Deep Face Super-Resolution with Iterative Collaboration between Attentive Recovery and Landmark Estimation.

Parameters
  • in_channels (int) – Number of channels in the input image.

  • out_channels (int) – Number of channels in the output image.

  • mid_channels (int) – Channel number of intermediate features.

  • num_blocks (int) – Number of blocks in the trunk network. Default: 6

  • hg_mid_channels (int) – Channel number of intermediate features of the hourglass. Default: 256

  • hg_num_keypoints (int) – Number of keypoints of the hourglass. Default: 68

  • num_steps (int) – Number of iterative steps. Default: 4

  • upscale_factor (int) – Upsampling factor. Default: 8

  • detach_attention (bool) – Whether to detach the heatmap used for attention from the computation graph. Default: False

  • prelu_init (float) – Initial value of the PReLU weight. Default: 0.2

  • num_heatmaps (int) – Number of heatmaps. Default: 5

  • num_fusion_blocks (int) – Number of fusion blocks. Default: 7

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor.

Returns

Forward results:
  • sr_outputs (list[Tensor]) – Forward SR results.

  • heatmap_outputs (list[Tensor]) – Forward heatmap results.

Return type

Tuple[list[Tensor], list[Tensor]]
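The iterative collaboration behind `forward` can be sketched as the control flow below. It assumes, consistent with the descriptions of `FeedbackBlockCustom` and `FeedbackBlockHeatmapAttention`, that the first step runs without a landmark prior and each later step is guided by the heatmap estimated in the previous step. `sr_branch` and `landmark_branch` are hypothetical stand-ins for the SR and hourglass sub-networks, and plain numbers replace tensors.

```python
from typing import Any, Callable, List, Optional, Tuple


def iterative_collaboration(
        lr: Any,
        num_steps: int,
        sr_branch: Callable[[Any, Optional[Any]], Any],
        landmark_branch: Callable[[Any], Any]) -> Tuple[List[Any], List[Any]]:
    """Hypothetical DIC-style loop returning one SR result and heatmap per step."""
    sr_outputs, heatmap_outputs = [], []
    heatmap = None  # no landmark prior exists before the first step
    for _ in range(num_steps):
        sr = sr_branch(lr, heatmap)      # recover the image, guided by the last heatmap
        heatmap = landmark_branch(sr)    # estimate landmarks on the new SR result
        sr_outputs.append(sr)
        heatmap_outputs.append(heatmap)
    return sr_outputs, heatmap_outputs


calls = []  # records which heatmap prior each SR step received


def toy_sr(lr, heatmap):
    calls.append(heatmap)
    return lr + 1 if heatmap is None else lr + heatmap


def toy_landmarks(sr):
    return sr * 2


sr_outputs, heatmap_outputs = iterative_collaboration(
    1, num_steps=4, sr_branch=toy_sr, landmark_branch=toy_landmarks)
print(calls)  # [None, 4, 10, 22]
```

With the default upscale_factor of 8, a 16 x 16 input gives a 128 x 128 SR output, which is the input size expected by the LightCNN discriminator on this page.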

class mmedit.models.editors.dic.FeedbackBlock(mid_channels, num_blocks, upscale_factor, padding=2, prelu_init=0.2)[source]

Bases: torch.nn.Module

Feedback Block of DIC.

It has the following structure:

----- Module ----->
  ^            |
  |____________|

Parameters
  • mid_channels (int) – Number of channels in the intermediate features.

  • num_blocks (int) – Number of blocks.

  • upscale_factor (int) – Upscale factor.

  • padding (int) – Padding size. Default: 2.

  • prelu_init (float) – Initial value of the PReLU weight. Default: 0.2

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor

class mmedit.models.editors.dic.FeedbackBlockCustom(in_channels, mid_channels, num_blocks, upscale_factor)[source]

Bases: FeedbackBlock

Custom feedback block, used as the first feedback block.

Parameters
  • in_channels (int) – Number of channels in the input features.

  • mid_channels (int) – Number of channels in the intermediate features.

  • num_blocks (int) – Number of blocks.

  • upscale_factor (int) – Upscale factor.

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor

class mmedit.models.editors.dic.FeedbackBlockHeatmapAttention(mid_channels, num_blocks, upscale_factor, num_heatmaps, num_fusion_blocks, padding=2, prelu_init=0.2)[source]

Bases: FeedbackBlock

Feedback block with HeatmapAttention.

Parameters
  • mid_channels (int) – Number of channels in the intermediate features.

  • num_blocks (int) – Number of blocks.

  • upscale_factor (int) – Upscale factor.

  • num_heatmaps (int) – Number of heatmaps.

  • num_fusion_blocks (int) – Number of fusion blocks.

  • padding (int) – Padding size. Default: 2.

  • prelu_init (float) – Initial value of the PReLU weight. Default: 0.2

forward(x, heatmap)

Forward function.

Parameters
  • x (Tensor) – Input feature tensor.

  • heatmap (Tensor) – Input heatmap tensor.

Returns

Forward results.

Return type

Tensor

class mmedit.models.editors.dic.FeedbackHourglass(mid_channels, num_keypoints)[source]

Bases: mmengine.model.BaseModule

Feedback Hourglass model for face landmark.

It has the following structure:

-- preprocessing ----- Hourglass ----->
                   ^               |
                   |_______________|

Parameters
  • mid_channels (int) – Number of channels in the intermediate features.

  • num_keypoints (int) – Number of keypoints.

forward(x, last_hidden=None)

Forward function.

Parameters
  • x (Tensor) – Input tensor with shape (n, c, h, w).

  • last_hidden (Tensor | None) – The feedback of FeedbackHourglass. In the first step, last_hidden is None; otherwise, it is the feedback output by FeedbackHourglass in the previous step. Default: None.

Returns

  • heatmap (Tensor) – Heatmap of facial landmarks.

  • feedback (Tensor) – Feedback tensor.

Return type

Tuple[Tensor, Tensor]
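The `last_hidden` contract of `FeedbackHourglass.forward` can be illustrated with the loop below: pass `None` at the first step, then feed back the `feedback` value returned by the previous call. `toy_hourglass` is a hypothetical stand-in that returns a `(heatmap, feedback)` pair as described in the Returns section, with numbers in place of tensors.

```python
from typing import Any, Callable, List, Optional, Tuple


def run_feedback_steps(
        model: Callable[[Any, Optional[Any]], Tuple[Any, Any]],
        inputs: List[Any]) -> List[Any]:
    """Call a FeedbackHourglass-like model once per step, threading the feedback."""
    heatmaps = []
    last_hidden = None  # first step: no feedback available yet
    for x in inputs:
        heatmap, last_hidden = model(x, last_hidden)
        heatmaps.append(heatmap)
    return heatmaps


def toy_hourglass(x, last_hidden=None):
    # Hypothetical stand-in: accumulate inputs through the feedback path.
    hidden = x if last_hidden is None else x + last_hidden
    return hidden * 10, hidden  # (heatmap, feedback)


print(run_feedback_steps(toy_hourglass, [1, 2, 3]))  # [10, 30, 60]
```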

class mmedit.models.editors.dic.LightCNN(in_channels)[source]

Bases: mmengine.model.BaseModule

LightCNN discriminator with input size 128 x 128.

It is used to train DICGAN.

Parameters

in_channels (int) – Channel number of inputs.

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor.

Returns

Forward results.

Return type

Tensor

init_weights(pretrained=None, strict=True)

Initialize weights for the model.

Parameters
  • pretrained (str, optional) – Path for pretrained weights. If None, pretrained weights will not be loaded. Defaults to None.

  • strict (bool, optional) – Whether to strictly load the pretrained model. Defaults to True.

class mmedit.models.editors.dic.MaxFeature(in_channels, out_channels, kernel_size=3, stride=1, padding=1, filter_type='conv2d')[source]

Bases: torch.nn.Module

Conv2d or Linear layer with max feature selector.

Generates feature maps with twice the number of output channels, splits them into two halves, and selects the element-wise maximum of the two.

Parameters
  • in_channels (int) – Channel number of inputs.

  • out_channels (int) – Channel number of outputs.

  • kernel_size (int or tuple) – Size of the convolving kernel.

  • stride (int or tuple, optional) – Stride of the convolution. Default: 1

  • padding (int or tuple, optional) – Zero-padding added to both sides of the input. Default: 1

  • filter_type (str) – Type of filter. Options are ‘conv2d’ and ‘linear’. Default: ‘conv2d’.
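The max-feature selection can be sketched on a single position's feature vector, assuming the underlying `conv2d` or `linear` filter has already produced `2 * out_channels` values and that the two halves are compared element-wise. Plain lists replace tensors; `max_feature` is a hypothetical helper, not part of the module's API.

```python
from typing import List


def max_feature(features: List[float], out_channels: int) -> List[float]:
    """Split 2 * out_channels values into two halves and keep the element-wise max."""
    if len(features) != 2 * out_channels:
        raise ValueError(f'expected {2 * out_channels} values, got {len(features)}')
    first, second = features[:out_channels], features[out_channels:]
    return [max(a, b) for a, b in zip(first, second)]


# Filter output for out_channels=2: halves [1.0, 5.0] and [3.0, 2.0].
print(max_feature([1.0, 5.0, 3.0, 2.0], out_channels=2))  # [3.0, 5.0]
```

This is why the filter is built with twice the requested output channels: the selection halves the channel count back to out_channels.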

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor.

Returns

Forward results.

Return type

Tensor
