Shortcuts

mmedit.models.editors

Package Contents

Classes

AOTBlockNeck

Dilation backbone used in AOT-GAN model.

AOTEncoderDecoder

Encoder-Decoder used in AOT-GAN model.

AOTInpaintor

Inpaintor for AOT-GAN method.

IDLossModel

Face id loss model.

BasicVSR

BasicVSR model for video super-resolution.

BasicVSRNet

BasicVSR network structure for video super-resolution.

BasicVSRPlusPlusNet

BasicVSR++ network structure.

BigGAN

Implementation of `Large Scale GAN Training for High Fidelity Natural

CAIN

CAIN model for Video Interpolation.

CAINNet

CAIN network structure.

CycleGAN

CycleGAN model for unpaired image-to-image translation.

DCGAN

Implementation of `Unsupervised Representation Learning with Deep

DDIMScheduler

`DDIMScheduler` supports the diffusion and reverse process formulated

DDPMScheduler

DenoisingUnet

Denoising Unet. This network receives a diffused image x_t and

ContextualAttentionModule

Contextual attention module.

ContextualAttentionNeck

Neck with contextual attention module.

DeepFillDecoder

Decoder used in DeepFill model.

DeepFillEncoder

Encoder used in DeepFill model.

DeepFillRefiner

Refiner used in DeepFill model.

DeepFillv1Discriminators

Discriminators used in DeepFillv1 model.

DeepFillv1Inpaintor

Inpaintor for deepfillv1 method.

DeepFillEncoderDecoder

Two-stage encoder-decoder structure used in DeepFill model.

DIC

DIC model for Face Super-Resolution.

DICNet

DIC network structure for face super-resolution.

FeedbackBlock

Feedback Block of DIC.

FeedbackBlockCustom

Custom feedback block, will be used as the first feedback block.

FeedbackBlockHeatmapAttention

Feedback block with HeatmapAttention.

LightCNN

LightCNN discriminator with input size 128 x 128.

MaxFeature

Conv2d or Linear layer with max feature selector.

DIM

Deep Image Matting model.

ClipWrapper

CLIP model wrapper for Disco Diffusion.

DiscoDiffusion

Disco Diffusion (DD) is a Google Colab Notebook which leverages an AI

EDSRNet

EDSR network structure.

EDVR

EDVR model for video super-resolution.

EDVRNet

EDVR network structure for video super-resolution.

EG3D

Implementation of `Efficient Geometry-aware 3D Generative Adversarial

ESRGAN

Enhanced SRGAN model for single image super-resolution.

RRDBNet

Networks consisting of Residual in Residual Dense Block, which is used

FBADecoder

Decoder for FBA matting.

FBAResnetDilated

ResNet-based encoder for FBA image matting.

FLAVR

FLAVR model for video interpolation.

FLAVRNet

PyTorch implementation of FLAVR for video frame interpolation.

GCA

Guided Contextual Attention image matting model.

GGAN

Implementation of Geometric GAN.

GLEANStyleGANv2

GLEAN (using StyleGANv2) architecture for super-resolution.

GLDecoder

Decoder used in Global&Local model.

GLDilationNeck

Dilation Backbone used in Global&Local model.

GLEncoder

Encoder used in Global&Local model.

GLEncoderDecoder

Encoder-Decoder used in Global&Local model.

AblatedDiffusionModel

Guided diffusion Model.

IconVSRNet

IconVSR network structure for video super-resolution.

DepthwiseIndexBlock

Depthwise index block.

HolisticIndexBlock

Holistic Index Block.

IndexedUpsample

Indexed upsample module.

IndexNet

IndexNet matting model.

IndexNetDecoder

Decoder for IndexNet.

IndexNetEncoder

Encoder for IndexNet.

InstColorization

InstColorization method for image colorization.

LIIF

LIIF model for single image super-resolution.

MLPRefiner

Multilayer perceptrons (MLPs), refiner used in LIIF.

LSGAN

Implementation of Least Squares Generative Adversarial Networks.

MSPIEStyleGAN2

MS-PIE StyleGAN2.

PESinGAN

Positional Encoding in SinGAN.

NAFBaseline

The original version of Baseline model in "Simple Baseline for Image

NAFBaselineLocal

The original version of Baseline model in "Simple Baseline for Image

NAFNet

NAFNet.

NAFNetLocal

The original version of NAFNetLocal in "Simple Baseline for Image

MaskConvModule

Mask convolution module.

PartialConv2d

Implementation for partial convolution.

PConvDecoder

Decoder with partial conv.

PConvEncoder

Encoder with partial conv.

PConvEncoderDecoder

Encoder-Decoder with partial conv module.

PConvInpaintor

Inpaintor for Partial Convolution method.

ProgressiveGrowingGAN

Progressive Growing Unconditional GAN.

Pix2Pix

Pix2Pix model for paired image-to-image translation.

PlainDecoder

Simple decoder from Deep Image Matting.

PlainRefiner

Simple refiner from Deep Image Matting.

RDNNet

RDN model for single image super-resolution.

RealBasicVSR

RealBasicVSR model for real-world video super-resolution.

RealBasicVSRNet

RealBasicVSR network structure for real-world video super-resolution.

RealESRGAN

Real-ESRGAN model for single image super-resolution.

UNetDiscriminatorWithSpectralNorm

A U-Net discriminator with spectral normalization.

SAGAN

Implementation of Self-Attention Generative Adversarial Networks.

SinGAN

SinGAN.

SRCNNNet

SRCNN network structure for image super-resolution.

SRGAN

SRGAN model for single image super-resolution.

ModifiedVGG

A modified VGG discriminator with input size 128 x 128.

MSRResNet

Modified SRResNet.

StyleGAN1

Implementation of `A Style-Based Generator Architecture for Generative

StyleGAN2

Implementation of `Analyzing and Improving the Image Quality of

StyleGAN3

Implementation of Alias-Free Generative Adversarial Networks.

StyleGAN3Generator

StyleGAN3 Generator.

TDAN

TDAN model for video super-resolution.

TDANNet

TDAN network structure for video super-resolution.

TOFlowVFINet

PyTorch implementation of TOFlow for video frame interpolation.

TOFlowVSRNet

PyTorch implementation of TOFlow.

ToFResBlock

ResNet architecture.

LTE

Learnable Texture Extractor.

TTSR

TTSR model for Reference-based Image Super-Resolution.

SearchTransformer

Search texture reference by transformer.

TTSRDiscriminator

A discriminator for TTSR.

TTSRNet

TTSR network structure (main-net) for reference-based super-resolution.

WGANGP

Implementation of Improved Training of Wasserstein GANs.

class mmedit.models.editors.AOTBlockNeck(in_channels=256, dilation_rates=(1, 2, 4, 8), num_aotblock=8, act_cfg=dict(type='ReLU'), **kwargs)

Bases: torch.nn.Module

Dilation backbone used in AOT-GAN model.

This implementation follows: Aggregated Contextual Transformations for High-Resolution Image Inpainting

Parameters
  • in_channels (int, optional) – Channel number of input feature. Default: 256.

  • dilation_rates (Tuple[int], optional) – The dilation rates used for the AOT block. Default: (1, 2, 4, 8).

  • num_aotblock (int, optional) – Number of AOT blocks. Default: 8.

  • act_cfg (dict, optional) – Config dict for activation layer, “relu” by default.

  • kwargs (keyword arguments) – Other keyword arguments.

forward(x)
class mmedit.models.editors.AOTEncoderDecoder(encoder=dict(type='AOTEncoder'), decoder=dict(type='AOTDecoder'), dilation_neck=dict(type='AOTBlockNeck'))

Bases: mmedit.models.editors.global_local.GLEncoderDecoder

Encoder-Decoder used in AOT-GAN model.

This implementation follows: Aggregated Contextual Transformations for High-Resolution Image Inpainting The architecture of the encoder-decoder is: (conv2d x 3) –> (dilated conv2d x 8) –> (conv2d or deconv2d x 3).

Parameters
  • encoder (dict) – Config dict to build encoder.

  • decoder (dict) – Config dict to build decoder.

  • dilation_neck (dict) – Config dict to build dilation neck.
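Pulling together the defaults from the signatures above, a config for this encoder-decoder might look like the following sketch. It is built from plain Python dicts; the nested dilation-neck fields simply restate the AOTBlockNeck defaults documented earlier.

```python
# Sketch of a config dict for AOTEncoderDecoder, restating the default
# sub-module configs from the class signature above.
aot_encdec_cfg = dict(
    type='AOTEncoderDecoder',
    encoder=dict(type='AOTEncoder'),
    decoder=dict(type='AOTDecoder'),
    dilation_neck=dict(
        type='AOTBlockNeck',
        in_channels=256,
        dilation_rates=(1, 2, 4, 8),
        num_aotblock=8,
        act_cfg=dict(type='ReLU'),
    ),
)
```

In an MMEditing config file such a dict would be handed to the model registry; here it only illustrates how the pieces documented above fit together.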

class mmedit.models.editors.AOTInpaintor(data_preprocessor: Union[dict, mmengine.config.Config], encdec, disc=None, loss_gan=None, loss_gp=None, loss_disc_shift=None, loss_composed_percep=None, loss_out_percep=False, loss_l1_hole=None, loss_l1_valid=None, loss_tv=None, train_cfg=None, test_cfg=None, init_cfg: Optional[dict] = None)

Bases: mmedit.models.base_models.OneStageInpaintor

Inpaintor for AOT-GAN method.

This inpaintor is implemented according to the paper: Aggregated Contextual Transformations for High-Resolution Image Inpainting

forward_train_d(data_batch, is_real, is_disc, mask)

Forward function in discriminator training step.

In this function, we compute the prediction for each data batch (real or fake). Meanwhile, the standard GAN loss is computed together with several proposed losses for stable training.

Parameters
  • data_batch (torch.Tensor) – Batch of real data or fake data.

  • is_real (bool) – If True, the gan loss will regard this batch as real data. Otherwise, the gan loss will regard this batch as fake data.

  • is_disc (bool) – If True, this function is called in discriminator training step. Otherwise, this function is called in generator training step. This will help us to compute different types of adversarial loss, like LSGAN.

  • mask (torch.Tensor) – Mask of data.

Returns

Contains the loss items computed in this function.

Return type

dict

generator_loss(fake_res, fake_img, gt, mask, masked_img)

Forward function in generator training step.

In this function, we mainly compute the loss items for generator with the given (fake_res, fake_img). In general, the fake_res is the direct output of the generator and the fake_img is the composition of direct output and ground-truth image.

Parameters
  • fake_res (torch.Tensor) – Direct output of the generator.

  • fake_img (torch.Tensor) – Composition of fake_res and ground-truth image.

  • gt (torch.Tensor) – Ground-truth image.

  • mask (torch.Tensor) – Mask image.

  • masked_img (torch.Tensor) – Composition of mask image and ground-truth image.

Returns

A tuple of two dicts: one containing the results computed within this function for visualization, and one containing the loss items computed in this function.

Return type

tuple(dict)
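The relationship between fake_res, mask, and fake_img described above is the usual inpainting composition: generator output inside the holes, ground truth elsewhere. A minimal sketch over flat lists (the real code operates on tensors; `compose` is an illustrative helper, not part of the API):

```python
def compose(fake_res, gt, mask):
    """Blend generator output into the ground truth using the hole mask.

    mask == 1 marks hole pixels that take the generator's values;
    mask == 0 keeps the ground-truth pixels.
    """
    return [f * m + g * (1 - m) for f, g, m in zip(fake_res, gt, mask)]

# Three "pixels": the first and last are holes, the middle one is valid.
fake_img = compose([0.2, 0.8, 0.5], [0.0, 1.0, 1.0], [1, 0, 1])
```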

forward_tensor(inputs, data_samples)

Forward function in tensor mode.

Parameters
  • inputs (torch.Tensor) – Input tensor.

  • data_samples (List[dict]) – List of data sample dict.

Returns

Direct output of the generator and the composition of fake_res and the ground-truth image.

Return type

tuple

train_step(data: List[dict], optim_wrapper)

Train step function.

In this function, the inpaintor will finish the train step following the pipeline:

  1. get fake res/image
  2. compute reconstruction losses for generator
  3. compute adversarial loss for discriminator
  4. optimize generator
  5. optimize discriminator

Parameters
  • data (List[dict]) – Batch of data as input.

  • optim_wrapper (dict[torch.optim.Optimizer]) – Dict with optimizers for generator and discriminator (if have).

Returns

Dict with loss, information for logger, the number of samples and results for visualization.

Return type

dict

class mmedit.models.editors.IDLossModel(ir_se50_weights=None)

Bases: torch.nn.Module

Face id loss model.

Parameters

ir_se50_weights (str, optional) – Url of ir-se50 weights. Defaults to None.

_ir_se50_url = https://gg0ltg.by.files.1drv.com/y4m3fNNszG03z9n8JQ7EhdtQKW8tQVQMFBisPVRgoXi_UfP8pKSSqv8RJNmHy2Ja...
extract_feats(x)

Extracting face features.

Parameters

x (torch.Tensor) – Image tensor of faces.

Returns

Face features.

Return type

torch.Tensor

forward(pred=None, gt=None)

Calculate face loss.

Parameters
  • pred (torch.Tensor, optional) – Predictions of face images. Defaults to None.

  • gt (torch.Tensor, optional) – Ground truth of face images. Defaults to None.

Returns

A tuple containing the face similarity loss and the improvement.

Return type

Tuple(float, float)

class mmedit.models.editors.BasicVSR(generator, pixel_loss, ensemble=None, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)

Bases: mmedit.models.BaseEditModel

BasicVSR model for video super-resolution.

Note that this model is used for IconVSR.

Paper:

BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond, CVPR, 2021

Parameters
  • generator (dict) – Config for the generator structure.

  • pixel_loss (dict) – Config for pixel-wise loss.

  • ensemble (dict) – Config for ensemble. Default: None.

  • train_cfg (dict) – Config for training. Default: None.

  • test_cfg (dict) – Config for testing. Default: None.

  • init_cfg (dict, optional) – The weight initialized config for BaseModule.

  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.

check_if_mirror_extended(lrs)

Check whether the input is a mirror-extended sequence.

If mirror-extended, the i-th (i=0, …, t-1) frame is equal to the (t-1-i)-th frame.

Parameters

lrs (tensor) – Input LR images with shape (n, t, c, h, w)
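The mirror-extension check described above can be sketched in pure Python (the actual method compares tensor chunks along the time dimension; `is_mirror_extended` is an illustrative stand-in):

```python
def is_mirror_extended(frames):
    """Return True if the i-th frame equals the (t-1-i)-th frame,
    i.e. the sequence is its own reversal (a mirror extension)."""
    t = len(frames)
    return all(frames[i] == frames[t - 1 - i] for i in range(t // 2))

mirrored = is_mirror_extended(['f0', 'f1', 'f2', 'f2', 'f1', 'f0'])
plain = is_mirror_extended(['f0', 'f1', 'f2', 'f3'])
```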

forward_train(inputs, data_samples=None, **kwargs)

Forward training. Returns dict of losses of training.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

Dict of losses.

Return type

dict

forward_inference(inputs, data_samples=None, **kwargs)

Forward inference. Returns predictions of validation, testing.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

predictions.

Return type

List[EditDataSample]

class mmedit.models.editors.BasicVSRNet(mid_channels=64, num_blocks=30, spynet_pretrained=None)

Bases: mmengine.model.BaseModule

BasicVSR network structure for video super-resolution.

Support only x4 upsampling.

Paper:

BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond, CVPR, 2021

Parameters
  • mid_channels (int) – Channel number of the intermediate features. Default: 64.

  • num_blocks (int) – Number of residual blocks in each propagation branch. Default: 30.

  • spynet_pretrained (str) – Pre-trained model path of SPyNet. Default: None.

check_if_mirror_extended(lrs)

Check whether the input is a mirror-extended sequence.

If mirror-extended, the i-th (i=0, …, t-1) frame is equal to the (t-1-i)-th frame.

Parameters

lrs (tensor) – Input LR images with shape (n, t, c, h, w)

compute_flow(lrs)

Compute optical flow using SPyNet for feature warping.

Note that if the input is a mirror-extended sequence, ‘flows_forward’ is not needed, since it is equal to ‘flows_backward.flip(1)’.

Parameters

lrs (tensor) – Input LR images with shape (n, t, c, h, w)

Returns

Optical flow. ‘flows_forward’ corresponds to the flows used for forward-time propagation (current to previous). ‘flows_backward’ corresponds to the flows used for backward-time propagation (current to next).

Return type

tuple(Tensor)

forward(lrs)

Forward function for BasicVSR.

Parameters

lrs (Tensor) – Input LR sequence with shape (n, t, c, h, w).

Returns

Output HR sequence with shape (n, t, c, 4h, 4w).

Return type

Tensor

class mmedit.models.editors.BasicVSRPlusPlusNet(mid_channels=64, num_blocks=7, max_residue_magnitude=10, is_low_res_input=True, spynet_pretrained=None, cpu_cache_length=100)

Bases: mmengine.model.BaseModule

BasicVSR++ network structure.

Support either x4 upsampling or same size output.

Paper:

BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment

Parameters
  • mid_channels (int, optional) – Channel number of the intermediate features. Default: 64.

  • num_blocks (int, optional) – The number of residual blocks in each propagation branch. Default: 7.

  • max_residue_magnitude (int) – The maximum magnitude of the offset residue (Eq. 6 in paper). Default: 10.

  • is_low_res_input (bool, optional) – Whether the input is low-resolution or not. If False, the output resolution is equal to the input resolution. Default: True.

  • spynet_pretrained (str, optional) – Pre-trained model path of SPyNet. Default: None.

  • cpu_cache_length (int, optional) – When the length of sequence is larger than this value, the intermediate features are sent to CPU. This saves GPU memory, but slows down the inference speed. You can increase this number if you have a GPU with large memory. Default: 100.

check_if_mirror_extended(lqs)

Check whether the input is a mirror-extended sequence.

If mirror-extended, the i-th (i=0, …, t-1) frame is equal to the (t-1-i)-th frame.

Parameters

lqs (tensor) – Input low quality (LQ) sequence with shape (n, t, c, h, w).

compute_flow(lqs)

Compute optical flow using SPyNet for feature alignment.

Note that if the input is a mirror-extended sequence, ‘flows_forward’ is not needed, since it is equal to ‘flows_backward.flip(1)’.

Parameters

lqs (tensor) – Input low quality (LQ) sequence with shape (n, t, c, h, w).

Returns

Optical flow. ‘flows_forward’ corresponds to the flows used for forward-time propagation (current to previous). ‘flows_backward’ corresponds to the flows used for backward-time propagation (current to next).

Return type

tuple(Tensor)

propagate(feats, flows, module_name)

Propagate the latent features throughout the sequence.

Parameters
  • feats (dict) – Features from previous branches. Each component is a list of tensors with shape (n, c, h, w).

  • flows (tensor) – Optical flows with shape (n, t - 1, 2, h, w).

  • module_name (str) – The name of the propagation branches. Can either be ‘backward_1’, ‘forward_1’, ‘backward_2’, ‘forward_2’.

Returns

A dictionary containing all the propagated features. Each key in the dictionary corresponds to a propagation branch, which is represented by a list of tensors.

Return type

dict(list[tensor])

upsample(lqs, feats)

Compute the output image given the features.

Parameters
  • lqs (tensor) – Input low quality (LQ) sequence with shape (n, t, c, h, w).

  • feats (dict) – The features from the propagation branches.

Returns

Output HR sequence with shape (n, t, c, 4h, 4w).

Return type

Tensor

forward(lqs)

Forward function for BasicVSR++.

Parameters

lqs (tensor) – Input low quality (LQ) sequence with shape (n, t, c, h, w).

Returns

Output HR sequence with shape (n, t, c, 4h, 4w).

Return type

Tensor

class mmedit.models.editors.BigGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, num_classes: Optional[int] = None, ema_config: Optional[Dict] = None)

Bases: mmedit.models.base_models.BaseConditionalGAN

Implementation of Large Scale GAN Training for High Fidelity Natural Image Synthesis (BigGAN).

Detailed architecture can be found in :class:`~mmgen.models.architectures.biggan.generator_discriminator.BigGANGenerator` and :class:`~mmgen.models.architectures.biggan.generator_discriminator.BigGANDiscriminator`.

Parameters
  • generator (ModelType) – The config or model of the generator.

  • discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.

  • data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config of GANDataPreprocessor.

  • generator_steps (int) – Number of times the generator was completely updated before the discriminator is updated. Defaults to 1.

  • discriminator_steps (int) – Number of times the discriminator was completely updated before the generator is updated. Defaults to 1.

  • noise_size (Optional[int]) – Size of the input noise vector. Defaults to 128.

  • num_classes (Optional[int]) – The number of classes you would like to generate. Defaults to None.

  • ema_config (Optional[Dict]) – The config for generator’s exponential moving average setting. Defaults to None.

disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor) Tuple

Get disc loss. BigGAN uses hinge loss to train the discriminator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]

gen_loss(disc_pred_fake)

Get gen loss. BigGAN uses hinge loss to train the generator.

Parameters

disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]
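The hinge losses referred to by disc_loss and gen_loss can be sketched over plain score lists. The real implementation works on tensors through a loss module; these helper names are illustrative:

```python
def disc_hinge_loss(disc_pred_real, disc_pred_fake):
    """Discriminator hinge loss:
    mean(relu(1 - D(x))) + mean(relu(1 + D(G(z))))."""
    relu = lambda v: max(0.0, v)
    loss_real = sum(relu(1.0 - p) for p in disc_pred_real) / len(disc_pred_real)
    loss_fake = sum(relu(1.0 + p) for p in disc_pred_fake) / len(disc_pred_fake)
    return loss_real + loss_fake

def gen_hinge_loss(disc_pred_fake):
    """Generator hinge loss: -mean(D(G(z)))."""
    return -sum(disc_pred_fake) / len(disc_pred_fake)
```

Real samples scored above 1 and fake samples scored below -1 contribute nothing to the discriminator loss, which is what makes the hinge formulation stable for large-scale training.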

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train discriminator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train generator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Not used in generator’s training.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

class mmedit.models.editors.CAIN(generator, pixel_loss, train_cfg=None, test_cfg=None, required_frames=2, step_frames=1, init_cfg=None, data_preprocessor=None)

Bases: mmedit.models.base_models.BasicInterpolator

CAIN model for Video Interpolation.

Paper: Channel Attention Is All You Need for Video Frame Interpolation Ref repo: https://github.com/myungsub/CAIN

Parameters
  • generator (dict) – Config for the generator structure.

  • pixel_loss (dict) – Config for pixel-wise loss.

  • train_cfg (dict) – Config for training. Default: None.

  • test_cfg (dict) – Config for testing. Default: None.

  • required_frames (int) – Required frames in each process. Default: 2

  • step_frames (int) – Step size of video frame interpolation. Default: 1

  • init_cfg (dict, optional) – The weight initialized config for BaseModule.

  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.

init_cfg

Initialization config dict.

Type

dict, optional

data_preprocessor

Used for pre-processing data sampled by dataloader to the format accepted by forward().

Type

BaseDataPreprocessor

forward_inference(inputs, data_samples=None)

Forward inference. Returns predictions of validation, testing, and simple inference.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

predictions.

Return type

List[EditDataSample]

class mmedit.models.editors.CAINNet(in_channels=3, kernel_size=3, num_block_groups=5, num_block_layers=12, depth=3, reduction=16, norm=None, padding=7, act=nn.LeakyReLU(0.2, True), init_cfg=None)

Bases: mmengine.model.BaseModule

CAIN network structure.

Paper: Channel Attention Is All You Need for Video Frame Interpolation. Ref repo: https://github.com/myungsub/CAIN

Parameters
  • in_channels (int) – Channel number of inputs. Default: 3.

  • kernel_size (int) – Kernel size of CAINNet. Default: 3.

  • num_block_groups (int) – Number of block groups. Default: 5.

  • num_block_layers (int) – Number of blocks in a group. Default: 12.

  • depth (int) – Down scale depth, scale = 2**depth. Default: 3.

  • reduction (int) – Channel reduction of CA. Default: 16.

  • norm (str | None) – Normalization layer. If it is None, no normalization is performed. Default: None.

  • padding (int) – Padding of CAINNet. Default: 7.

  • act (function) – Activation function. Default: nn.LeakyReLU(0.2, True).

  • init_cfg (dict, optional) – Initialization config dict. Default: None.

forward(imgs, padding_flag=False)

Forward function.

Parameters
  • imgs (Tensor) – Input tensor with shape (n, 2, c, h, w).

  • padding_flag (bool) – Padding or not. Default: False.

Returns

Forward results.

Return type

Tensor

class mmedit.models.editors.CycleGAN(*args, buffer_size=50, loss_config=dict(cycle_loss_weight=10.0, id_loss_weight=0.5), **kwargs)

Bases: mmedit.models.base_models.BaseTranslationModel

CycleGAN model for unpaired image-to-image translation.

Ref: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
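With the default loss_config from the signature above (cycle_loss_weight=10.0, id_loss_weight=0.5), the generator objective weights the cycle-consistency and identity terms against the adversarial term. A sketch of that weighting; the function name and scalar-loss interface are illustrative assumptions, not the library API:

```python
def cyclegan_gen_objective(loss_gan, loss_cycle, loss_id,
                           cycle_loss_weight=10.0, id_loss_weight=0.5):
    """Combine the adversarial, cycle-consistency and identity losses
    using the default weights from loss_config."""
    return loss_gan + cycle_loss_weight * loss_cycle + id_loss_weight * loss_id

total = cyclegan_gen_objective(loss_gan=1.0, loss_cycle=0.1, loss_id=0.2)
```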

forward_test(img, target_domain, **kwargs)

Forward function for testing.

Parameters
  • img (tensor) – Input image tensor.

  • target_domain (str) – Target domain of output image.

  • kwargs (dict) – Other arguments.

Returns

Forward results.

Return type

dict

_get_disc_loss(outputs)

Backward function for the discriminators.

Parameters

outputs (dict) – Dict of forward results.

Returns

Discriminators’ loss and loss dict.

Return type

dict

_get_gen_loss(outputs)

Backward function for the generators.

Parameters

outputs (dict) – Dict of forward results.

Returns

Generators’ loss and loss dict.

Return type

dict

_get_opposite_domain(domain)

Get the opposite domain with respect to the input domain.

Parameters

domain (str) – The input domain.

Returns

The opposite domain.

Return type

str

train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict)

Training step function.

Parameters
  • data (dict) – Dict of the input data batch.

  • optim_wrapper (OptimWrapperDict) – Dict of optimizer wrappers for the generators and discriminators.

Returns

Dict of loss, information for logger, the number of samples and results for visualization.

Return type

dict

test_step(data: dict) mmedit.utils.typing.SampleList

Gets the generated image of given data. Same as val_step().

Parameters

data (dict) – Data sampled from metric specific sampler. More details in Metrics and Evaluator.

Returns

A list of EditDataSample containing generated results.

Return type

SampleList

val_step(data: dict) mmedit.utils.typing.SampleList

Gets the generated image of given data. Same as test_step().

Parameters

data (dict) – Data sampled from metric specific sampler. More details in Metrics and Evaluator.

Returns

A list of EditDataSample containing generated results.

Return type

SampleList

class mmedit.models.editors.DCGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None)

Bases: mmedit.models.base_models.BaseGAN

Implementation of Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.

Paper link: https://arxiv.org/abs/1511.06434 (DCGAN).

Detailed architecture can be found in :class:`~mmgen.models.architectures.dcgan.generator_discriminator.DCGANGenerator` and :class:`~mmgen.models.architectures.dcgan.generator_discriminator.DCGANDiscriminator`.

disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor) Tuple

Get disc loss. DCGAN uses the vanilla GAN loss to train the discriminator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]

gen_loss(disc_pred_fake: torch.Tensor) Tuple

Get gen loss. DCGAN uses the vanilla GAN loss to train the generator.

Parameters

disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]
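The vanilla GAN losses used by disc_loss and gen_loss can be sketched on probabilities in (0, 1). The actual implementation works on logits through a loss module; these helper names are illustrative:

```python
import math

def vanilla_disc_loss(prob_real, prob_fake):
    """Vanilla GAN discriminator loss:
    -mean(log D(x)) - mean(log(1 - D(G(z))))."""
    loss_real = -sum(math.log(p) for p in prob_real) / len(prob_real)
    loss_fake = -sum(math.log(1.0 - p) for p in prob_fake) / len(prob_fake)
    return loss_real + loss_fake

def vanilla_gen_loss(prob_fake):
    """Non-saturating generator loss: -mean(log D(G(z)))."""
    return -sum(math.log(p) for p in prob_fake) / len(prob_fake)
```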

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train discriminator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train generator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Not used in generator’s training.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

class mmedit.models.editors.DDIMScheduler(num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear', variance_type='learned_range', timestep_values=None, clip_sample=True, set_alpha_to_one=True)

`DDIMScheduler` supports the diffusion and reverse process formulated in https://arxiv.org/abs/2010.02502.

The code is heavily influenced by https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py. The difference is that gradient-guided sampling is incorporated into the step function.

Parameters
  • num_train_timesteps (int, optional) – Number of diffusion steps used to train the model. Defaults to 1000.

  • beta_start (float, optional) – Starting beta value of the noise schedule. Defaults to 0.0001.

  • beta_end (float, optional) – Final beta value of the noise schedule. Defaults to 0.02.

  • beta_schedule (str, optional) – The beta schedule, mapping a beta range to a sequence of betas used for stepping the model. Defaults to “linear”.

  • variance_type (str, optional) – How the variance of the reverse process is parameterized. Defaults to ‘learned_range’.

  • timestep_values (list, optional) – Custom values for the sampling timesteps. Defaults to None.

  • clip_sample (bool, optional) – Whether to clip the predicted sample to [-1, 1] for numerical stability. Defaults to True.

  • set_alpha_to_one (bool, optional) – Whether to fix the alpha cumulative product of the step before the first to 1; otherwise the alpha value at step 0 is used. Defaults to True.
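With the default beta_schedule='linear', the scheduler's noise levels come from betas spaced evenly between beta_start and beta_end, and from the cumulative products of alpha_t = 1 - beta_t. A sketch of that construction under the defaults above (the library may discretize the schedule slightly differently):

```python
def linear_betas(num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02):
    """Evenly spaced per-step noise variances beta_t."""
    step = (beta_end - beta_start) / (num_train_timesteps - 1)
    return [beta_start + i * step for i in range(num_train_timesteps)]

def alphas_cumprod(betas):
    """Cumulative products of alpha_t = 1 - beta_t, used by both the
    diffusion (add_noise) and reverse (step) processes."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out

betas = linear_betas()
alpha_bars = alphas_cumprod(betas)  # decreases monotonically toward 0
```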

set_timesteps(num_inference_steps, offset=0)
_get_variance(timestep, prev_timestep)
step(model_output: Union[torch.FloatTensor, numpy.ndarray], timestep: int, sample: Union[torch.FloatTensor, numpy.ndarray], cond_fn=None, cond_kwargs={}, eta: float = 0.0, use_clipped_model_output: bool = False, generator=None)
add_noise(original_samples, noise, timesteps)
__len__()
class mmedit.models.editors.DDPMScheduler(num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear', trained_betas=None, variance_type='fixed_small', clip_sample=True)
set_timesteps(num_inference_steps)
_get_variance(t, predicted_variance=None, variance_type=None)
step(model_output: Union[torch.FloatTensor], timestep: int, sample: Union[torch.FloatTensor], predict_epsilon=True, generator=None)
add_noise(original_samples, noise, timesteps)
abstract training_loss(model, x_0, t)
abstract sample_timestep()
__len__()
class mmedit.models.editors.DenoisingUnet(image_size, in_channels=3, base_channels=128, resblocks_per_downsample=3, num_timesteps=1000, use_rescale_timesteps=False, dropout=0, embedding_channels=-1, num_classes=0, use_fp16=False, channels_cfg=None, output_cfg=dict(mean='eps', var='learned_range'), norm_cfg=dict(type='GN', num_groups=32), act_cfg=dict(type='SiLU', inplace=False), shortcut_kernel_size=1, use_scale_shift_norm=False, resblock_updown=False, num_heads=4, time_embedding_mode='sin', time_embedding_cfg=None, resblock_cfg=dict(type='DenoisingResBlock'), attention_cfg=dict(type='MultiHeadAttention'), downsample_conv=True, upsample_conv=True, downsample_cfg=dict(type='DenoisingDownsample'), upsample_cfg=dict(type='DenoisingUpsample'), attention_res=[16, 8], pretrained=None)

Bases: mmengine.model.BaseModule

Denoising Unet. This network receives a diffused image x_t and the current timestep t, and returns an output_dict corresponding to the passed output_cfg.

output_cfg defines the number of channels and the meaning of the output. output_cfg mainly contains keys of mean and var, denoting how the network outputs the mean and variance required for the denoising process.

For mean:

  1. dict(mean='EPS'): Model will predict the noise added in the diffusion process, and the output_dict will contain a key named eps_t_pred.

  2. dict(mean='START_X'): Model will directly predict the mean of the original image x_0, and the output_dict will contain a key named x_0_pred.

  3. dict(mean='X_TM1_PRED'): Model will predict the mean of the diffused image at the t-1 timestep, and the output_dict will contain a key named x_tm1_pred.

For var: 1. dict(var='FIXED_SMALL') or dict(var='FIXED_LARGE'): Variance in

the denoising process is regarded as a fixed value. Therefore only ‘mean’ will be predicted, and the output channels will equal to the input image (e.g., three channels for RGB image.)

  1. dict(var='LEARNED'): Model will predict log_variance in the

    denoising process, and the output_dict will contain a key named log_var.

  2. dict(var='LEARNED_RANGE'): Model will predict an interpolation

    factor and the log_variance will be calculated as factor * upper_bound + (1-factor) * lower_bound. The output_dict will contain a key named factor.

If var is not FIXED_SMALL or FIXED_LARGE, the number of output channels will be the double of input channels, where the first half part contains predicted mean values and the other part is the predicted variance values. Otherwise, the number of output channels equals to the input channels, only containing the predicted mean values.
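The channel convention above can be sketched as a simple split. A hypothetical NumPy illustration (the helper name and shapes are assumptions, not mmedit's code):

```python
import numpy as np

def split_denoising_output(output, var='LEARNED_RANGE'):
    """Split a denoising network output into mean and variance parts.

    For learned-variance configs the output has 2x the input channels:
    the first half is the predicted mean, the second half the variance
    term (log variance for LEARNED, interpolation factor for LEARNED_RANGE).
    """
    if var in ('FIXED_SMALL', 'FIXED_LARGE'):
        return output, None                 # only the mean is predicted
    mean, var_part = np.split(output, 2, axis=1)
    return mean, var_part

out = np.zeros((2, 6, 16, 16))              # batch of 2, double channels for RGB
mean, factor = split_denoising_output(out)  # each half is (2, 3, 16, 16)
```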

Parameters
  • image_size (int | list[int]) – The size of image to denoise.

  • in_channels (int, optional) – The input channels of the input image. Defaults as 3.

  • base_channels (int, optional) – The basic channel number of the generator. The other layers contain channels based on this number. Defaults to 128.

  • resblocks_per_downsample (int, optional) – Number of ResBlock used between two downsample operations. The number of ResBlock between upsample operations will be the same value to keep symmetry. Defaults to 3.

  • num_timesteps (int, optional) – The total timestep of the denoising process and the diffusion process. Defaults to 1000.

  • use_rescale_timesteps (bool, optional) – Whether to rescale the input timesteps to the range of [0, 1000]. Defaults to False.

  • dropout (float, optional) – The dropout probability of each ResBlock. Pass 0 to disable dropout. Defaults to 0.

  • embedding_channels (int, optional) – The output channels of time embedding layer and label embedding layer. If not passed (or passed -1), output channels of the embedding layers will set as four times of base_channels. Defaults to -1.

  • num_classes (int, optional) – The number of conditional classes. If set to 0, this model will be degraded to an unconditional model. Defaults to 0.

  • channels_cfg (list | dict[list], optional) – Config for the input channels of the intermediate blocks. If a list is passed, each element indicates the scale factor for the input channels of the current block with regard to base_channels. For block i, the input and output channels should be channels_cfg[i] * base_channels and channels_cfg[i+1] * base_channels. If a dict is provided, its keys should be the output scales and the corresponding values should be lists defining the channels. Default: Please refer to _default_channels_cfg.

  • output_cfg (dict, optional) – Config for output variables. Defaults to dict(mean='eps', var='learned_range').

  • norm_cfg (dict, optional) – The config for normalization layers. Defaults to dict(type='GN', num_groups=32).

  • act_cfg (dict, optional) – The config for activation layers. Defaults to dict(type='SiLU', inplace=False).

  • shortcut_kernel_size (int, optional) – The kernel size for the shortcut conv in ResBlocks. The value of this argument will overwrite the default value of resblock_cfg. Defaults to 1.

  • use_scale_shift_norm (bool, optional) – Whether to perform scale and shift after the normalization operation. Defaults to False.

  • num_heads (int, optional) – The number of attention heads. Defaults to 4.

  • time_embedding_mode (str, optional) – Embedding method of time_embedding. Defaults to ‘sin’.

  • time_embedding_cfg (dict, optional) – Config for time_embedding. Defaults to None.

  • resblock_cfg (dict, optional) – Config for ResBlock. Defaults to dict(type='DenoisingResBlock').

  • attention_cfg (dict, optional) – Config for attention operation. Defaults to dict(type='MultiHeadAttention').

  • upsample_conv (bool, optional) – Whether use conv in upsample block. Defaults to True.

  • downsample_conv (bool, optional) – Whether use conv operation in downsample block. Defaults to True.

  • upsample_cfg (dict, optional) – Config for upsample blocks. Defaults to dict(type='DenoisingUpsample').

  • downsample_cfg (dict, optional) – Config for downsample blocks. Defaults to dict(type='DenoisingDownsample').

  • attention_res (int | list[int], optional) – Resolution of feature maps to apply attention operation. Defaults to [16, 8].

  • pretrained (str | dict, optional) – Path for the pretrained model, or a dict containing information for pretrained models whose necessary key is ‘ckpt_path’. Besides, you can also provide ‘prefix’ to load the generator part from the whole state dict. Defaults to None.

_default_channels_cfg
forward(x_t, t, label=None, return_noise=False)

Forward function.

Parameters
  • x_t (torch.Tensor) – Diffused image at timestep t to denoise.

  • t (torch.Tensor) – Current timestep.

  • label (torch.Tensor | callable, optional) – You can directly give a batch of labels through a torch.Tensor, or offer a callable function to sample a batch of label data. Otherwise, None indicates using the default label sampler.

  • return_noise (bool, optional) – If True, the input x_t and t will be returned in a dict together with the output desired by output_cfg. Defaults to False.

Returns

The output desired by output_cfg if return_noise is False; otherwise, a dict that also contains the input x_t and t.

Return type

torch.Tensor | dict

init_weights(pretrained=None)

Init weights for models.

We just use the initialization method proposed in the original paper.

Parameters

pretrained (str, optional) – Path for pretrained weights. If given None, pretrained weights will not be loaded. Defaults to None.

convert_to_fp16()

Convert the precision of the model to float16.

convert_to_fp32()

Convert the precision of the model to float32.

class mmedit.models.editors.ContextualAttentionModule(unfold_raw_kernel_size=4, unfold_raw_stride=2, unfold_raw_padding=1, unfold_corr_kernel_size=3, unfold_corr_stride=1, unfold_corr_dilation=1, unfold_corr_padding=1, scale=0.5, fuse_kernel_size=3, softmax_scale=10, return_attention_score=True)

Bases: torch.nn.Module

Contextual attention module.

The details of this module can be found in: Generative Image Inpainting with Contextual Attention

Parameters
  • unfold_raw_kernel_size (int) – Kernel size used in unfolding raw feature. Default: 4.

  • unfold_raw_stride (int) – Stride used in unfolding raw feature. Default: 2.

  • unfold_raw_padding (int) – Padding used in unfolding raw feature. Default: 1.

  • unfold_corr_kernel_size (int) – Kernel size used in unfolding context for computing correlation maps. Default: 3.

  • unfold_corr_stride (int) – Stride used in unfolding context for computing correlation maps. Default: 1.

  • unfold_corr_dilation (int) – Dilation used in unfolding context for computing correlation maps. Default: 1.

  • unfold_corr_padding (int) – Padding used in unfolding context for computing correlation maps. Default: 1.

  • scale (float) – The rescale factor used to resize the input features. Default: 0.5.

  • fuse_kernel_size (int) – The kernel size used in fusion module. Default: 3.

  • softmax_scale (float) – The scale factor for softmax function. Default: 10.

  • return_attention_score (bool) – If True, the attention score will be returned. Default: True.

forward(x, context, mask=None)

Forward Function.

Parameters
  • x (torch.Tensor) – Tensor with shape (n, c, h, w).

  • context (torch.Tensor) – Tensor with shape (n, c, h, w).

  • mask (torch.Tensor) – Tensor with shape (n, 1, h, w). Default: None.

Returns

Features after contextual attention.

Return type

tuple(torch.Tensor)

patch_correlation(x, kernel)

Calculate patch correlation.

Parameters
  • x (torch.Tensor) – Input tensor.

  • kernel (torch.Tensor) – Kernel tensor.

Returns

Tensor with shape of (n, l, h, w).

Return type

torch.Tensor

patch_copy_deconv(attention_score, context_filter)

Copy patches using deconv.

Parameters
  • attention_score (torch.Tensor) – Tensor with shape of (n, l , h, w).

  • context_filter (torch.Tensor) – Filter kernel.

Returns

Tensor with shape of (n, c, h, w).

Return type

torch.Tensor

fuse_correlation_map(correlation_map, h_unfold, w_unfold)

Fuse correlation map.

This operation fuses the correlation map to encourage large consistent correlation regions.

The mechanism behind this op is simple. A standard ‘eye’ (identity) matrix is applied as a filter on the correlation map in the horizontal and vertical directions.

The shape of the input correlation map is (n, h_unfold*w_unfold, h, w). When fusing, we apply the convolutional filter to the reshaped feature map with shape of (n, 1, h_unfold*w_unfold, h*w).

A simple specification for horizontal direction is shown below:

       (h, (h, (h, (h,
        0)  1)  2)  3)  ...
(h, 0)
(h, 1)      1
(h, 2)          1
(h, 3)              1
...
calculate_unfold_hw(input_size, kernel_size=3, stride=1, dilation=1, padding=0)

Calculate (h, w) after unfolding.

PyTorch's official implementation of unfold folds the (h, w) dimensions into L. Thus, this function just calculates (h, w) according to the equation in: https://pytorch.org/docs/stable/nn.html#torch.nn.Unfold
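That equation can be written out directly; a minimal sketch of the computation, matching the torch.nn.Unfold output-size formula linked above:

```python
def calculate_unfold_hw(input_size, kernel_size=3, stride=1, dilation=1, padding=0):
    """(h, w) after nn.Unfold, per the PyTorch docs:
    out = floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1
    """
    h, w = input_size
    h_unfold = (h + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    w_unfold = (w + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    return h_unfold, w_unfold

# A 3x3 kernel with stride 1 and no padding shrinks each side by 2:
# calculate_unfold_hw((8, 8)) -> (6, 6)
```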

calculate_overlap_factor(attention_score)

Calculate the overlap factor after applying deconv.

Parameters

attention_score (torch.Tensor) – The attention score with shape of (n, c, h, w).

Returns

The overlap factor will be returned.

Return type

torch.Tensor

mask_correlation_map(correlation_map, mask)

Add mask weight for correlation map.

Add negative infinity to the masked regions so that the softmax function will produce zero in those regions.

Parameters
  • correlation_map (torch.Tensor) – Correlation map with shape of (n, h_unfold*w_unfold, h_map, w_map).

  • mask (torch.Tensor) – Mask tensor with shape of (n, c, h, w). ‘1’ in the mask indicates masked region while ‘0’ indicates valid region.

Returns

Updated correlation map with mask.

Return type

torch.Tensor
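The masking trick above can be sketched in a few lines of NumPy: writing negative infinity into masked positions makes softmax assign them exactly zero weight (illustrative, not the actual implementation):

```python
import numpy as np

def mask_correlation_map(correlation_map, mask):
    """Set masked positions to -inf so softmax assigns them zero weight."""
    return np.where(mask > 0, -np.inf, correlation_map)

def softmax(x):
    x = x - np.max(x)              # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum()

scores = np.array([1.0, 2.0, 3.0])
mask = np.array([0, 1, 0])         # the middle position is masked
weights = softmax(mask_correlation_map(scores, mask))
# weights[1] is exactly 0.0: the masked position receives no attention
```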

im2col(img, kernel_size, stride=1, padding=0, dilation=1, normalize=False, return_cols=False)

Reshape image-style feature to columns.

This function is used for unfold feature maps to columns. The details of this function can be found in: https://pytorch.org/docs/1.1.0/nn.html?highlight=unfold#torch.nn.Unfold

Parameters
  • img (torch.Tensor) – Features to be unfolded. The shape of this feature should be (n, c, h, w).

  • kernel_size (int) – In this function, we only support square kernel with same height and width.

  • stride (int) – Stride number in unfolding. Default: 1.

  • padding (int) – Padding number in unfolding. Default: 0.

  • dilation (int) – Dilation number in unfolding. Default: 1.

  • normalize (bool) – If True, the unfolded feature will be normalized. Default: False.

  • return_cols (bool) – The official PyTorch implementation of unfolding returns features with shape of (n, c * kernel_size^2, L). If True, the features will be reshaped to (n, L, c, kernel_size, kernel_size). Otherwise, the results will keep the shape of the official implementation.

Returns

Unfolded columns. If return_cols is True, the shape of the output tensor is (n, L, c, kernel_size, kernel_size). Otherwise, the shape will be (n, c * kernel_size^2, L).

Return type

torch.Tensor
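For intuition, unfolding to columns can be sketched in pure NumPy (a naive loop version, not mmedit's implementation, which relies on torch.nn.Unfold):

```python
import numpy as np

def im2col(img, kernel_size, stride=1):
    """Unfold (n, c, h, w) into columns of shape (n, c*k*k, L)."""
    n, c, h, w = img.shape
    k = kernel_size
    h_out = (h - k) // stride + 1
    w_out = (w - k) // stride + 1
    cols = np.empty((n, c * k * k, h_out * w_out), dtype=img.dtype)
    idx = 0
    for i in range(0, h - k + 1, stride):
        for j in range(0, w - k + 1, stride):
            patch = img[:, :, i:i + k, j:j + k]   # (n, c, k, k) window
            cols[:, :, idx] = patch.reshape(n, -1)
            idx += 1
    return cols

x = np.arange(2 * 3 * 4 * 4, dtype=float).reshape(2, 3, 4, 4)
cols = im2col(x, kernel_size=3)
# cols.shape == (2, 27, 4): 3*3*3 values per column, 2x2 spatial positions
```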

class mmedit.models.editors.ContextualAttentionNeck(in_channels, conv_type='conv', conv_cfg=None, norm_cfg=None, act_cfg=dict(type='ELU'), contextual_attention_args=dict(softmax_scale=10.0), **kwargs)

Bases: torch.nn.Module

Neck with contextual attention module.

Parameters
  • in_channels (int) – The number of input channels.

  • conv_type (str) – The type of conv module. In DeepFillv1 model, the conv_type should be ‘conv’. In DeepFillv2 model, the conv_type should be ‘gated_conv’.

  • conv_cfg (dict | None) – Config of conv module. Default: None.

  • norm_cfg (dict | None) – Config of norm module. Default: None.

  • act_cfg (dict | None) – Config of activation layer. Default: dict(type=’ELU’).

  • contextual_attention_args (dict) – Config of contextual attention module. Default: dict(softmax_scale=10.).

  • kwargs (keyword arguments) –

_conv_type
forward(x, mask)

Forward Function.

Parameters
  • x (torch.Tensor) – Input tensor with shape of (n, c, h, w).

  • mask (torch.Tensor) – Input tensor with shape of (n, 1, h, w).

Returns

Output tensor with shape of (n, c, h’, w’).

Return type

torch.Tensor

class mmedit.models.editors.DeepFillDecoder(in_channels, conv_type='conv', norm_cfg=None, act_cfg=dict(type='ELU'), out_act_cfg=dict(type='clip', min=- 1.0, max=1.0), channel_factor=1.0, **kwargs)

Bases: torch.nn.Module

Decoder used in DeepFill model.

This implementation follows: Generative Image Inpainting with Contextual Attention

Parameters
  • in_channels (int) – The number of input channels.

  • conv_type (str) – The type of conv module. In DeepFillv1 model, the conv_type should be ‘conv’. In DeepFillv2 model, the conv_type should be ‘gated_conv’.

  • norm_cfg (dict) – Config dict to build norm layer. Default: None.

  • act_cfg (dict) – Config dict for activation layer, “elu” by default.

  • out_act_cfg (dict) – Config dict for output activation layer. Here, we provide commonly used clamp or clip operation.

  • channel_factor (float) – The scale factor for channel size. Default: 1.

  • kwargs (keyword arguments) –

_conv_type
forward(input_dict)

Forward Function.

Parameters

input_dict (dict | torch.Tensor) – Input dict with middle features or torch.Tensor.

Returns

Output tensor with shape of (n, c, h, w).

Return type

torch.Tensor

class mmedit.models.editors.DeepFillEncoder(in_channels=5, conv_type='conv', norm_cfg=None, act_cfg=dict(type='ELU'), encoder_type='stage1', channel_factor=1.0, **kwargs)

Bases: torch.nn.Module

Encoder used in DeepFill model.

This implementation follows: Generative Image Inpainting with Contextual Attention

Parameters
  • in_channels (int) – The number of input channels. Default: 5.

  • conv_type (str) – The type of conv module. In DeepFillv1 model, the conv_type should be ‘conv’. In DeepFillv2 model, the conv_type should be ‘gated_conv’.

  • norm_cfg (dict) – Config dict to build norm layer. Default: None.

  • act_cfg (dict) – Config dict for activation layer, “elu” by default.

  • encoder_type (str) – Type of the encoder. Should be one of [‘stage1’, ‘stage2_conv’, ‘stage2_attention’]. Default: ‘stage1’.

  • channel_factor (float) – The scale factor for channel size. Default: 1.

  • kwargs (keyword arguments) –

_conv_type
forward(x)

Forward Function.

Parameters

x (torch.Tensor) – Input tensor with shape of (n, c, h, w).

Returns

Output tensor with shape of (n, c, h’, w’).

Return type

torch.Tensor

class mmedit.models.editors.DeepFillRefiner(encoder_attention=dict(type='DeepFillEncoder', encoder_type='stage2_attention'), encoder_conv=dict(type='DeepFillEncoder', encoder_type='stage2_conv'), dilation_neck=dict(type='GLDilationNeck', in_channels=128, act_cfg=dict(type='ELU')), contextual_attention=dict(type='ContextualAttentionNeck', in_channels=128), decoder=dict(type='DeepFillDecoder', in_channels=256))

Bases: torch.nn.Module

Refiner used in DeepFill model.

This implementation follows: Generative Image Inpainting with Contextual Attention.

Parameters
  • encoder_attention (dict) – Config dict for encoder used in branch with contextual attention module.

  • encoder_conv (dict) – Config dict for encoder used in branch with just convolutional operation.

  • dilation_neck (dict) – Config dict for dilation neck in branch with just convolutional operation.

  • contextual_attention (dict) – Config dict for contextual attention neck.

  • decoder (dict) – Config dict for decoder used to fuse and decode features.

forward(x, mask)

Forward Function.

Parameters
  • x (torch.Tensor) – Input tensor with shape of (n, c, h, w).

  • mask (torch.Tensor) – Input tensor with shape of (n, 1, h, w).

Returns

Output tensor with shape of (n, c, h’, w’).

Return type

torch.Tensor

class mmedit.models.editors.DeepFillv1Discriminators(global_disc_cfg, local_disc_cfg)

Bases: torch.nn.Module

Discriminators used in DeepFillv1 model.

In the DeepFillv1 model, the discriminators are independent, without the concatenation used in the Global&Local model. Thus, we call this model DeepFillv1Discriminators. There is a global discriminator and a local discriminator, taking the global image and the local image patch as input respectively.

The details can be found in: Generative Image Inpainting with Contextual Attention.

Parameters
  • global_disc_cfg (dict) – Config dict for global discriminator.

  • local_disc_cfg (dict) – Config dict for local discriminator.

forward(x)

Forward function.

Parameters

x (tuple[torch.Tensor]) – Contains global image and the local image patch.

Returns

Contains the prediction from discriminators in global image and local image patch.

Return type

tuple[torch.Tensor]

init_weights(pretrained=None)

Init weights for models.

Parameters

pretrained (str, optional) – Path for pretrained weights. If given None, pretrained weights will not be loaded. Defaults to None.

class mmedit.models.editors.DeepFillv1Inpaintor(data_preprocessor: dict, encdec: dict, disc=None, loss_gan=None, loss_gp=None, loss_disc_shift=None, loss_composed_percep=None, loss_out_percep=False, loss_l1_hole=None, loss_l1_valid=None, loss_tv=None, stage1_loss_type=None, stage2_loss_type=None, train_cfg=None, test_cfg=None, init_cfg: Optional[dict] = None)

Bases: mmedit.models.base_models.TwoStageInpaintor

Inpaintor for deepfillv1 method.

This inpaintor is implemented according to the paper: Generative image inpainting with contextual attention

Importantly, this inpaintor is an example of using a custom training schedule based on TwoStageInpaintor.

The training pipeline of deepfillv1 is as follows:

if cur_iter < iter_tc:
    update generator with only l1 loss
else:
    update discriminator
    if cur_iter > iter_td:
        update generator with l1 loss and adversarial loss

The new attribute cur_iter is added for recording current number of iteration. The train_cfg contains the setting of the training schedule:

train_cfg = dict(
    start_iter=0,
    disc_step=1,
    iter_tc=90000,
    iter_td=100000
)

iter_tc and iter_td correspond to the notation \(T_C\) and \(T_D\) of the original paper.
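The schedule above reduces to plain control flow; a simplified sketch (the function name and the returned action labels are illustrative, not the actual train_step):

```python
def training_actions(cur_iter, iter_tc=90000, iter_td=100000):
    """Return which updates run at this iteration, per the DeepFillv1 schedule."""
    if cur_iter < iter_tc:
        return ['generator_l1']             # warm up G with only the l1 loss
    actions = ['discriminator']             # T_C reached: start training D
    if cur_iter > iter_td:
        actions.append('generator_l1_adv')  # T_D passed: G gets l1 + adversarial
    return actions

# training_actions(0)      -> ['generator_l1']
# training_actions(95000)  -> ['discriminator']
# training_actions(150000) -> ['discriminator', 'generator_l1_adv']
```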

Parameters
  • generator (dict) – Config for encoder-decoder style generator.

  • disc (dict) – Config for discriminator.

  • loss_gan (dict) – Config for adversarial loss.

  • loss_gp (dict) – Config for gradient penalty loss.

  • loss_disc_shift (dict) – Config for discriminator shift loss.

  • loss_composed_percep (dict) – Config for perceptual and style loss with composed image as input.

  • loss_out_percep (dict) – Config for perceptual and style loss with direct output as input.

  • loss_l1_hole (dict) – Config for l1 loss in the hole.

  • loss_l1_valid (dict) – Config for l1 loss in the valid region.

  • loss_tv (dict) – Config for total variation loss.

  • train_cfg (dict) – Configs for the training scheduler. disc_step must be contained to indicate the number of discriminator updating steps in each training step.

  • test_cfg (dict) – Configs for testing scheduler.

  • init_cfg (dict, optional) – Initialization config dict.

forward_train_d(data_batch, is_real, is_disc)

Forward function in discriminator training step.

In this function, we modify the default implementation with only one discriminator. In DeepFillv1 model, they use two separated discriminators for global and local consistency.

Parameters
  • data_batch (torch.Tensor) – Batch of real data or fake data.

  • is_real (bool) – If True, the gan loss will regard this batch as real data. Otherwise, the gan loss will regard this batch as fake data.

  • is_disc (bool) – If True, this function is called in discriminator training step. Otherwise, this function is called in generator training step. This will help us to compute different types of adversarial loss, like LSGAN.

Returns

Contains the loss items computed in this function.

Return type

dict

two_stage_loss(stage1_data, stage2_data, gt, mask, masked_img)

Calculate two-stage loss.

Parameters
  • stage1_data (dict) – Contain stage1 results.

  • stage2_data (dict) – Contain stage2 results.

  • gt (torch.Tensor) – Ground-truth image.

  • mask (torch.Tensor) – Mask image.

  • masked_img (torch.Tensor) – Composition of mask image and ground-truth image.

Returns

Dict contains the results computed within this function for visualization and dict contains the loss items computed in this function.

Return type

tuple(dict)

calculate_loss_with_type(loss_type, fake_res, fake_img, gt, mask, prefix='stage1_', fake_local=None)

Calculate multiple types of losses.

Parameters
  • loss_type (str) – Type of the loss.

  • fake_res (torch.Tensor) – Direct results from model.

  • fake_img (torch.Tensor) – Composited results from model.

  • gt (torch.Tensor) – Ground-truth tensor.

  • mask (torch.Tensor) – Mask tensor.

  • prefix (str, optional) – Prefix for loss name. Defaults to ‘stage1_’. # noqa

  • fake_local (torch.Tensor, optional) – Local results from model. Defaults to None.

Returns

Contain loss value with its name.

Return type

dict

train_step(data: List[dict], optim_wrapper)

Train step function.

In this function, the inpaintor will finish the train step following the pipeline:

  1. get fake res/image

  2. optimize discriminator (if any)

  3. optimize generator

If self.train_cfg.disc_step > 1, the train step will contain multiple iterations for optimizing the discriminator with different input data, and only one iteration for optimizing the generator after disc_step iterations for the discriminator.

Parameters
  • data (List[dict]) – Batch of data as input.

  • optim_wrapper (dict[torch.optim.Optimizer]) – Dict with optimizers for generator and discriminator (if any).

Returns

Dict with loss, information for logger, the number of samples and results for visualization.

Return type

dict

class mmedit.models.editors.DeepFillEncoderDecoder(stage1=dict(type='GLEncoderDecoder', encoder=dict(type='DeepFillEncoder'), decoder=dict(type='DeepFillDecoder', in_channels=128), dilation_neck=dict(type='GLDilationNeck', in_channels=128, act_cfg=dict(type='ELU'))), stage2=dict(type='DeepFillRefiner'), return_offset=False)

Bases: torch.nn.Module

Two-stage encoder-decoder structure used in DeepFill model.

The details are in: Generative Image Inpainting with Contextual Attention

Parameters
  • stage1 (dict) – Config dict for building stage1 model. As DeepFill model uses Global&Local model as baseline in first stage, the stage1 model can be easily built with GLEncoderDecoder.

  • stage2 (dict) – Config dict for building stage2 model.

  • return_offset (bool) – Whether to return offset feature in contextual attention module. Default: False.

forward(x)

Forward function.

Parameters

x (torch.Tensor) – This input tensor has the shape of (n, 5, h, w). In channel dimension, we concatenate [masked_img, ones, mask] as DeepFillv1 models do.

Returns

The first two items are the results from the first and second stages. If return_offset is set to True, the offset will be returned as the third item.

Return type

tuple[torch.Tensor]

init_weights(pretrained=None)

Init weights for models.

Parameters

pretrained (str, optional) – Path for pretrained weights. If given None, pretrained weights will not be loaded. Defaults to None.

class mmedit.models.editors.DIC(generator, pixel_loss, align_loss, discriminator=None, gan_loss=None, feature_loss=None, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)

Bases: mmedit.models.editors.srgan.SRGAN

DIC model for Face Super-Resolution.

Paper: Deep Face Super-Resolution with Iterative Collaboration between Attentive Recovery and Landmark Estimation.

Parameters
  • generator (dict) – Config for the generator.

  • pixel_loss (dict) – Config for the pixel loss.

  • align_loss (dict) – Config for the align loss.

  • discriminator (dict) – Config for the discriminator. Default: None.

  • gan_loss (dict) – Config for the gan loss. Default: None.

  • feature_loss (dict) – Config for the feature loss. Default: None.

  • train_cfg (dict) – Config for train. Default: None.

  • test_cfg (dict) – Config for testing. Default: None.

  • init_cfg (dict, optional) – The weight initialized config for BaseModule. Default: None.

  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor. Default: None.

forward_tensor(inputs, data_samples=None, training=False)

Forward tensor. Returns result of simple forward.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

  • training (bool) – Whether is training. Default: False.

Returns

Results of forward inference and forward train.

Return type

(Tensor | Tuple[List[Tensor]])

if_run_g()

Calculates whether the generator step needs to run.

if_run_d()

Calculates whether the discriminator step needs to run.

g_step(batch_outputs, batch_gt_data)

G step of GAN: Calculate losses of generator.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

Returns

Dict of losses.

Return type

dict

d_step_with_optim(batch_outputs, batch_gt_data, optim_wrapper)

D step with optim of GAN: Calculate losses of discriminator and run optim.

Parameters
  • batch_outputs (Tuple[Tensor]) – Batch output of generator.

  • batch_gt_data (Tuple[Tensor]) – Batch GT data.

  • optim_wrapper (OptimWrapper) – Optim wrapper of discriminator.

Returns

Dict of parsed losses.

Return type

dict

static extract_gt_data(data_samples)

Extract gt data from data samples.

Parameters

data_samples (list) – List of EditDataSample.

Returns

Extract gt data.

Return type

Tensor

class mmedit.models.editors.DICNet(in_channels, out_channels, mid_channels, num_blocks=6, hg_mid_channels=256, hg_num_keypoints=68, num_steps=4, upscale_factor=8, detach_attention=False, prelu_init=0.2, num_heatmaps=5, num_fusion_blocks=7)

Bases: mmengine.model.BaseModule

DIC network structure for face super-resolution.

Paper: Deep Face Super-Resolution with Iterative Collaboration between Attentive Recovery and Landmark Estimation.

Parameters
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels in the output image

  • mid_channels (int) – Channel number of intermediate features. Default: 64

  • num_blocks (int) – Number of blocks in the trunk network. Default: 6

  • hg_mid_channels (int) – Channel number of intermediate features of HourGlass. Default: 256

  • hg_num_keypoints (int) – Keypoint number of HourGlass. Default: 68

  • num_steps (int) – Number of iterative steps. Default: 4

  • upscale_factor (int) – Upsampling factor. Default: 8

  • detach_attention (bool) – Whether to detach the heatmap from the current tensor. Default: False.

  • prelu_init (float) – init of PReLU. Default: 0.2

  • num_heatmaps (int) – Number of heatmaps. Default: 5

  • num_fusion_blocks (int) – Number of fusion blocks. Default: 7

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor.

Returns

Forward results: sr_outputs (list[Tensor]), the forward SR results, and heatmap_outputs (list[Tensor]), the forward heatmap results.

Return type

Tensor

class mmedit.models.editors.FeedbackBlock(mid_channels, num_blocks, upscale_factor, padding=2, prelu_init=0.2)

Bases: torch.nn.Module

Feedback Block of DIC.

It has a style of:

----- Module ----->
  ^            |
  |____________|
Parameters
  • mid_channels (int) – Number of channels in the intermediate features.

  • num_blocks (int) – Number of blocks.

  • upscale_factor (int) – upscale factor.

  • padding (int) – Padding size. Default: 2.

  • prelu_init (float) – init of PReLU. Default: 0.2

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor
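The feedback style in the diagram above can be illustrated schematically: each step's output is fed back as the hidden state for the next step. A toy NumPy stand-in (the averaging update is a placeholder for the real convolutional module):

```python
import numpy as np

def feedback_run(x, num_steps=4):
    """Iterative feedback: each step's output becomes the next hidden state."""
    hidden = np.zeros_like(x)
    outputs = []
    for _ in range(num_steps):
        hidden = 0.5 * (x + hidden)    # placeholder for the feedback module
        outputs.append(hidden)
    return outputs

outs = feedback_run(np.ones(3))
# Each pass refines the estimate: 0.5, 0.75, 0.875, 0.9375, approaching the input.
```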

class mmedit.models.editors.FeedbackBlockCustom(in_channels, mid_channels, num_blocks, upscale_factor)

Bases: FeedbackBlock

Custom feedback block, will be used as the first feedback block.

Parameters
  • in_channels (int) – Number of channels in the input features.

  • mid_channels (int) – Number of channels in the intermediate features.

  • num_blocks (int) – Number of blocks.

  • upscale_factor (int) – upscale factor.

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor

class mmedit.models.editors.FeedbackBlockHeatmapAttention(mid_channels, num_blocks, upscale_factor, num_heatmaps, num_fusion_blocks, padding=2, prelu_init=0.2)

Bases: FeedbackBlock

Feedback block with HeatmapAttention.

Parameters
  • mid_channels (int) – Number of channels in the intermediate features.

  • num_blocks (int) – Number of blocks.

  • upscale_factor (int) – Upscale factor.

  • num_heatmaps (int) – Number of heatmaps.

  • num_fusion_blocks (int) – Number of fusion blocks.

  • padding (int) – Padding size. Default: 2.

  • prelu_init (float) – init of PReLU. Default: 0.2

forward(x, heatmap)

Forward function.

Parameters
  • x (Tensor) – Input feature tensor.

  • heatmap (Tensor) – Input heatmap tensor.

Returns

Forward results.

Return type

Tensor

class mmedit.models.editors.LightCNN(in_channels)

Bases: mmengine.model.BaseModule

LightCNN discriminator with input size 128 x 128.

It is used to train DICGAN.

Parameters

in_channels (int) – Channel number of inputs.

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor.

Returns

Forward results.

Return type

Tensor

init_weights(pretrained=None, strict=True)

Init weights for models.

Parameters
  • pretrained (str, optional) – Path for pretrained weights. If given None, pretrained weights will not be loaded. Defaults to None.

  • strict (bool, optional) – Whether to strictly load the pretrained model. Defaults to True.

class mmedit.models.editors.MaxFeature(in_channels, out_channels, kernel_size=3, stride=1, padding=1, filter_type='conv2d')

Bases: torch.nn.Module

Conv2d or Linear layer with max feature selector.

Generate feature maps with double channels, split them, and select the max feature.

Parameters
  • in_channels (int) – Channel number of inputs.

  • out_channels (int) – Channel number of outputs.

  • kernel_size (int or tuple) – Size of the convolving kernel.

  • stride (int or tuple, optional) – Stride of the convolution. Default: 1

  • padding (int or tuple, optional) – Zero-padding added to both sides of the input. Default: 1

  • filter_type (str) – Type of filter. Options are ‘conv2d’ and ‘linear’. Default: ‘conv2d’.

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor.

Returns

Forward results.

Return type

Tensor

class mmedit.models.editors.DIM(data_preprocessor, backbone, refiner=None, train_cfg=None, test_cfg=None, loss_alpha=None, loss_comp=None, loss_refine=None, init_cfg: Optional[dict] = None)

Bases: mmedit.models.base_models.BaseMattor

Deep Image Matting model.

https://arxiv.org/abs/1703.03872

Note

For (self.train_cfg.train_backbone, self.train_cfg.train_refiner):

  • (True, False) corresponds to the encoder-decoder stage in the paper.

  • (False, True) corresponds to the refinement stage in the paper.

  • (True, True) corresponds to the fine-tune stage in the paper.
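The note above maps the two flags to the paper's stages; as a plain-Python sketch (the function name is illustrative):

```python
def dim_stage(train_backbone, train_refiner):
    """Map (train_backbone, train_refiner) to the DIM training stage."""
    stages = {
        (True, False): 'encoder-decoder stage',
        (False, True): 'refinement stage',
        (True, True): 'fine-tune stage',
    }
    return stages[(train_backbone, train_refiner)]

# dim_stage(True, False) -> 'encoder-decoder stage'
```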

Parameters
  • data_preprocessor (dict, optional) – Config of data pre-processor.

  • backbone (dict) – Config of backbone.

  • refiner (dict) – Config of refiner.

  • loss_alpha (dict) – Config of the alpha prediction loss. Default: None.

  • loss_comp (dict) – Config of the composition loss. Default: None.

  • loss_refine (dict) – Config of the loss of the refiner. Default: None.

  • train_cfg (dict) – Config of training. In train_cfg, train_backbone should be specified. If the model has a refiner, train_refiner should be specified.

  • test_cfg (dict) – Config of testing. In test_cfg, if the model has a refiner, train_refiner should be specified.

  • init_cfg (dict, optional) – The weight initialized config for BaseModule. Default: None.

property with_refiner

Whether the matting model has a refiner.

init_weights()

Initialize the model network weights.

train(mode=True)

Mode switcher.

Parameters

mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.

freeze_backbone()

Freeze the backbone and only train the refiner.

_forward(x: torch.Tensor, *, refine: bool = True) Tuple[torch.Tensor, torch.Tensor]

Raw forward function.

Parameters
  • x (torch.Tensor) – Concatenation of merged image and trimap with shape (N, 4, H, W)

  • refine (bool) – if forward through refiner

Returns

pred_alpha, with shape (N, 1, H, W), and pred_refine, with shape (N, 4, H, W).

Return type

Tuple[torch.Tensor, torch.Tensor]

_forward_test(inputs)

Forward to get alpha prediction.

_forward_train(inputs, data_samples)

Defines the computation performed at every training call.

Parameters
  • inputs (torch.Tensor) – Concatenation of normalized image and trimap with shape (N, 4, H, W).

  • data_samples (list[EditDataSample]) – Data samples containing:

    • gt_alpha (Tensor): Ground-truth of alpha, shape (N, 1, H, W), normalized to 0 to 1.

    • gt_fg (Tensor): Ground-truth of foreground, shape (N, C, H, W), normalized to 0 to 1.

    • gt_bg (Tensor): Ground-truth of background, shape (N, C, H, W), normalized to 0 to 1.

Returns

Contains the loss items and batch information.

Return type

dict

class mmedit.models.editors.ClipWrapper(clip_type, *args, **kwargs)

Bases: torch.nn.Module

Clip Models wrapper for disco-diffusion.

We provide wrappers for the clip models of openai and mlfoundations, where the user can specify clip_type as clip or open_clip, and then initialize a clip model using the same arguments as in the original codebase. The following clip models settings are provided in the official repo of disco diffusion:

Setting                       | Source    | Arguments
------------------------------|-----------|------------------------------------------------------------
ViTB32                        | clip      | name='ViT-B/32', jit=False
ViTB16                        | clip      | name='ViT-B/16', jit=False
ViTL14                        | clip      | name='ViT-L/14', jit=False
ViTL14_336px                  | clip      | name='ViT-L/14@336px', jit=False
RN50                          | clip      | name='RN50', jit=False
RN50x4                        | clip      | name='RN50x4', jit=False
RN50x16                       | clip      | name='RN50x16', jit=False
RN50x64                       | clip      | name='RN50x64', jit=False
RN101                         | clip      | name='RN101', jit=False
ViTB32_laion2b_e16            | open_clip | name='ViT-B-32', pretrained='laion2b_e16'
ViTB32_laion400m_e31          | open_clip | model_name='ViT-B-32', pretrained='laion400m_e31'
ViTB32_laion400m_32           | open_clip | model_name='ViT-B-32', pretrained='laion400m_e32'
ViTB32quickgelu_laion400m_e31 | open_clip | model_name='ViT-B-32-quickgelu', pretrained='laion400m_e31'
ViTB32quickgelu_laion400m_e32 | open_clip | model_name='ViT-B-32-quickgelu', pretrained='laion400m_e32'
ViTB16_laion400m_e31          | open_clip | model_name='ViT-B-16', pretrained='laion400m_e31'
ViTB16_laion400m_e32          | open_clip | model_name='ViT-B-16', pretrained='laion400m_e32'
RN50_yffcc15m                 | open_clip | model_name='RN50', pretrained='yfcc15m'
RN50_cc12m                    | open_clip | model_name='RN50', pretrained='cc12m'
RN50_quickgelu_yfcc15m        | open_clip | model_name='RN50-quickgelu', pretrained='yfcc15m'
RN50_quickgelu_cc12m          | open_clip | model_name='RN50-quickgelu', pretrained='cc12m'
RN101_yfcc15m                 | open_clip | model_name='RN101', pretrained='yfcc15m'
RN101_quickgelu_yfcc15m       | open_clip | model_name='RN101-quickgelu', pretrained='yfcc15m'

An example of a clip_models config is as follows:

clip_models = [
    dict(type='ClipWrapper', clip_type='clip', name='ViT-B/32', jit=False),
    dict(type='ClipWrapper', clip_type='clip', name='ViT-B/16', jit=False),
    dict(type='ClipWrapper', clip_type='clip', name='RN50', jit=False),
]

Parameters

clip_type (str) – The original source of the clip model. Should be clip or open_clip.

forward(*args, **kwargs)

Forward function.

class mmedit.models.editors.DiscoDiffusion(unet, diffusion_scheduler, secondary_model=None, clip_models=[], use_fp16=False, pretrained_cfgs=None)

Bases: torch.nn.Module

Disco Diffusion (DD) is a Google Colab Notebook which leverages an AI Image generating technique called CLIP-Guided Diffusion to allow you to create compelling and beautiful images from just text inputs. Created by Somnai, augmented by Gandamu, and building on the work of RiversHaveWings, nshepperd, and many others.

Ref:

Github Repo: https://github.com/alembics/disco-diffusion Colab: https://colab.research.google.com/github/alembics/disco-diffusion/blob/main/Disco_Diffusion.ipynb # noqa

Parameters
  • unet (ModelType) – Config of denoising Unet.

  • diffusion_scheduler (ModelType) – Config of diffusion_scheduler scheduler.

  • secondary_model (ModelType) – A smaller secondary diffusion model trained by Katherine Crowson to remove noise from intermediate timesteps to prepare them for CLIP. Ref: https://twitter.com/rivershavewings/status/1462859669454536711 # noqa Defaults to None.

  • clip_models (list) – Config of clip models. Defaults to [].

  • use_fp16 (bool) – Whether to use fp16 for unet model. Defaults to False.

  • pretrained_cfgs (dict) – Path Config for pretrained weights. Usually this is a dict contains module name and the corresponding ckpt path. Defaults to None.

property device

Get current device of the model.

Returns

The current device of the model.

Return type

torch.device

load_pretrained_models(pretrained_cfgs)

Load pretrained weights into the model. pretrained_cfgs is a dict with module names as keys and checkpoint paths as values.

Parameters

pretrained_cfgs (dict) – Path config for pretrained weights. Usually this is a dict containing the module name and the corresponding ckpt path. Defaults to None.

infer(scheduler_kwargs=None, height=None, width=None, init_image=None, batch_size=1, num_inference_steps=1000, skip_steps=0, show_progress=False, text_prompts=[], image_prompts=[], eta=0.8, clip_guidance_scale=5000, init_scale=1000, tv_scale=0.0, sat_scale=0.0, range_scale=150, cut_overview=[12] * 400 + [4] * 600, cut_innercut=[4] * 400 + [12] * 600, cut_ic_pow=[1] * 1000, cut_icgray_p=[0.2] * 400 + [0] * 600, cutn_batches=4, seed=None)

Inference API for disco diffusion.

Parameters
  • scheduler_kwargs (dict) – Args for infer time diffusion scheduler. Defaults to None.

  • height (int) – Height of output image. Defaults to None.

  • width (int) – Width of output image. Defaults to None.

  • init_image (str) – Initial image at the start point of denoising. Defaults to None.

  • batch_size (int) – Batch size. Defaults to 1.

  • num_inference_steps (int) – Number of inference steps. Defaults to 1000.

  • skip_steps (int) – Denoising steps to skip, usually set with init_image. Defaults to 0.

  • show_progress (bool) – Whether to show progress. Defaults to False.

  • text_prompts (list) – Text prompts. Defaults to [].

  • image_prompts (list) – Image prompts. These are not the same as init_image; they work in the same way as text_prompts. Defaults to [].

  • eta (float) – Eta for ddim sampling. Defaults to 0.8.

  • clip_guidance_scale (int) – The scale of influence of prompts on the output image. Defaults to 5000.

  • seed (int) – Sampling seed. Defaults to None.
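The cut schedules (cut_overview, cut_innercut, cut_ic_pow, cut_icgray_p) are per-timestep lists: entry i gives the value used at diffusion step i, e.g. cut_overview=[12]*400 + [4]*600 uses 12 overview cuts for the first 400 of 1000 steps and 4 thereafter. A sketch of how such a schedule might be read (cut_at_step is an illustrative helper; the actual indexing in disco-diffusion may differ):

```python
def cut_at_step(schedule, step, total_steps=1000):
    """Look up a per-timestep value from a schedule list such as
    cut_overview = [12] * 400 + [4] * 600 (one entry per step)."""
    # Scale the step index into the schedule's length, clamped to the end.
    idx = int(step * len(schedule) / total_steps)
    return schedule[min(idx, len(schedule) - 1)]

cut_overview = [12] * 400 + [4] * 600
print(cut_at_step(cut_overview, 100))  # 12 (early steps use more overview cuts)
print(cut_at_step(cut_overview, 900))  # 4
```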

class mmedit.models.editors.EDSRNet(in_channels, out_channels, mid_channels=64, num_blocks=16, upscale_factor=4, res_scale=1, rgb_mean=[0.4488, 0.4371, 0.404], rgb_std=[1.0, 1.0, 1.0])

Bases: mmengine.model.BaseModule

EDSR network structure.

Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution. Ref repo: https://github.com/thstkdgus35/EDSR-PyTorch

Parameters
  • in_channels (int) – Channel number of inputs.

  • out_channels (int) – Channel number of outputs.

  • mid_channels (int) – Channel number of intermediate features. Default: 64.

  • num_blocks (int) – Block number in the trunk network. Default: 16.

  • upscale_factor (int) – Upsampling factor. Support 2^n and 3. Default: 4.

  • res_scale (float) – Used to scale the residual in residual block. Default: 1.

  • rgb_mean (list[float]) – Image mean in RGB orders. Default: [0.4488, 0.4371, 0.4040], calculated from DIV2K dataset.

  • rgb_std (list[float]) – Image std in RGB orders. In EDSR, it uses [1.0, 1.0, 1.0]. Default: [1.0, 1.0, 1.0].

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor
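The documented constraint that upscale_factor must be a power of two or 3 can be sketched as a small check (is_supported_upscale is an illustrative helper, not mmedit API):

```python
def is_supported_upscale(factor):
    """EDSR-style upsampling supports powers of two (2^n) and 3.
    Sketch of the documented constraint, not the library's own check."""
    if factor == 3:
        return True
    # A positive integer is a power of two iff it has a single set bit.
    return factor >= 2 and (factor & (factor - 1)) == 0

print([f for f in range(1, 9) if is_supported_upscale(f)])  # [2, 3, 4, 8]
```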

class mmedit.models.editors.EDVR(generator, pixel_loss, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)

Bases: mmedit.models.BaseEditModel

EDVR model for video super-resolution.

EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.

Parameters
  • generator (dict) – Config for the generator structure.

  • pixel_loss (dict) – Config for pixel-wise loss.

  • train_cfg (dict) – Config for training. Default: None.

  • test_cfg (dict) – Config for testing. Default: None.

  • init_cfg (dict, optional) – The weight initialized config for BaseModule.

  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.

forward_train(inputs, data_samples=None)

Forward training. Returns dict of losses of training.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

Dict of losses.

Return type

dict

class mmedit.models.editors.EDVRNet(in_channels, out_channels, mid_channels=64, num_frames=5, deform_groups=8, num_blocks_extraction=5, num_blocks_reconstruction=10, center_frame_idx=2, with_tsa=True, init_cfg=None)

Bases: mmengine.model.BaseModule

EDVR network structure for video super-resolution.

Currently only the x4 upsampling factor is supported. Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.

Parameters
  • in_channels (int) – Channel number of inputs.

  • out_channels (int) – Channel number of outputs.

  • mid_channels (int) – Channel number of intermediate features. Default: 64.

  • num_frames (int) – Number of input frames. Default: 5.

  • deform_groups (int) – Deformable groups. Defaults: 8.

  • num_blocks_extraction (int) – Number of blocks for feature extraction. Default: 5.

  • num_blocks_reconstruction (int) – Number of blocks for reconstruction. Default: 10.

  • center_frame_idx (int) – The index of center frame. Frame counting from 0. Default: 2.

  • with_tsa (bool) – Whether to use TSA module. Default: True.

  • init_cfg (dict, optional) – Initialization config dict. Default: None.

forward(x)

Forward function for EDVRNet.

Parameters

x (Tensor) – Input tensor with shape (n, t, c, h, w).

Returns

SR center frame with shape (n, c, h, w).

Return type

Tensor
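With num_frames=5 and center_frame_idx=2, each temporal window of t frames is restored to its middle frame, which is why the output drops the t dimension. A sketch of that indexing (center_frame is an illustrative helper, not mmedit API):

```python
def center_frame(frames, center_frame_idx=None):
    """Pick the frame EDVR restores from a temporal window.
    Defaults to the middle index (2 for a 5-frame window)."""
    if center_frame_idx is None:
        center_frame_idx = len(frames) // 2
    return frames[center_frame_idx]

print(center_frame(['t-2', 't-1', 't', 't+1', 't+2']))  # 't'
```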

init_weights()

Init weights for models.

class mmedit.models.editors.EG3D(generator: ModelType, discriminator: Optional[ModelType] = None, camera: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None)

Bases: mmedit.models.base_models.BaseConditionalGAN

Implementation of Efficient Geometry-aware 3D Generative Adversarial Networks (EG3D). https://openaccess.thecvf.com/content/CVPR2022/papers/Chan_Efficient_Geometry-Aware_3D_Generative_Adversarial_Networks_CVPR_2022_paper.pdf # noqa

Detailed architecture can be found in :class:`~mmedit.models.editors.eg3d.eg3d_generator.TriplaneGenerator` and :class:`~mmedit.models.editors.eg3d.dual_discriminator.DualDiscriminator`.

Parameters
  • generator (ModelType) – The config or model of the generator.

  • discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.

  • camera (Optional[ModelType]) – The pre-defined camera to sample random camera position. If you want to generate images or videos via high-level API, you must set this argument. Defaults to None.

  • data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or GenDataPreprocessor.

  • generator_steps (int) – Number of times the generator was completely updated before the discriminator is updated. Defaults to 1.

  • discriminator_steps (int) – Number of times the discriminator was completely updated before the generator is updated. Defaults to 1.

  • noise_size (Optional[int]) – Size of the input noise vector. Default to 128.

  • num_classes (Optional[int]) – The number of classes you would like to generate. Defaults to None.

  • ema_config (Optional[Dict]) – The config for generator’s exponential moving average setting. Defaults to None.

  • loss_config (Optional[Dict]) – The config for training losses. Defaults to None.

label_fn(label: Optional[torch.Tensor] = None, num_batches: int = 1) torch.Tensor

Label sampling function for EG3D model.

Parameters

label (Optional[Tensor]) – Conditional for EG3D model. If not passed, self.camera will be used to sample random camera-to-world and intrinsics matrix. Defaults to None.

Returns

Conditional input for EG3D model.

Return type

torch.Tensor

data_sample_to_label(data_sample: mmedit.utils.typing.SampleList) Optional[torch.Tensor]

Get labels from input data_sample and pack to torch.Tensor. If no label is found in the passed data_sample, None would be returned.

Parameters

data_sample (List[EditDataSample]) – Input data samples.

Returns

Packed label tensor.

Return type

Optional[torch.Tensor]

pack_to_data_sample(output: Dict[str, torch.Tensor], index: int, data_sample: Optional[mmedit.structures.EditDataSample] = None) mmedit.structures.EditDataSample

Pack output to data sample. If data_sample is not passed, a new EditDataSample will be instantiated. Otherwise, outputs will be added to the passed datasample.

Parameters
  • output (Dict[Tensor]) – Output of the model.

  • index (int) – The index to save.

  • data_sample (EditDataSample, optional) – Data sample to save outputs. Defaults to None.

Returns

Data sample with packed outputs.

Return type

EditDataSample

forward(inputs: mmedit.utils.typing.ForwardInputs, data_samples: Optional[list] = None, mode: Optional[str] = None) List[mmedit.structures.EditDataSample]

Sample images with the given inputs. If forward mode is ‘ema’ or ‘orig’, the image generated by corresponding generator will be returned. If forward mode is ‘ema/orig’, images generated by original generator and EMA generator will both be returned in a dict.

Parameters
  • inputs (ForwardInputs) – Dict containing the necessary information (e.g. noise, num_batches, mode) to generate image.

  • data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.

  • mode (Optional[str]) – mode is not used in BaseConditionalGAN. Defaults to None.

Returns

Generated images or image dict.

Return type

List[EditDataSample]

interpolation(num_images: int, num_batches: int = 4, mode: str = 'both', sample_model: str = 'orig', show_pbar: bool = True) List[dict]

Interpolate inputs and return a list of output results. Three kinds of interpolation mode are supported:

  • ‘camera’: First generate style codes with random noise and forward camera. Then synthesize images with interpolated camera positions and fixed style codes.

  • ‘conditioning’: First generate style codes with fixed noise and interpolated camera. Then synthesize images with the style codes and forward camera.

  • ‘both’: Generate images with interpolated camera positions.

Parameters
  • num_images (int) – The number of images to generate.

  • num_batches (int, optional) – The number of batches to generate at one time. Defaults to 4.

  • mode (str, optional) – The interpolation mode. Supported choices are ‘both’, ‘camera’, and ‘conditioning’. Defaults to ‘both’.

  • sample_model (str, optional) – The model used to generate images, support ‘orig’ and ‘ema’. Defaults to ‘orig’.

  • show_pbar (bool, optional) – Whether display a progress bar during interpolation. Defaults to True.

Returns

The list of output dict of each frame.

Return type

List[dict]
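The ‘camera’ mode above interpolates the camera while fixing the style code. As a rough sketch, linearly interpolating a camera position between two endpoints looks like this (lerp_camera is illustrative; EG3D actually interpolates full camera-to-world and intrinsics matrices):

```python
def lerp_camera(pos_a, pos_b, num_frames):
    """Linearly interpolate between two camera positions, producing
    one position per output frame (sketch of the 'camera' mode)."""
    frames = []
    for i in range(num_frames):
        t = i / (num_frames - 1)
        frames.append([a + t * (b - a) for a, b in zip(pos_a, pos_b)])
    return frames

path = lerp_camera([0.0, 0.0, 1.0], [1.0, 0.0, 1.0], 3)
print(path)  # [[0.0, 0.0, 1.0], [0.5, 0.0, 1.0], [1.0, 0.0, 1.0]]
```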

class mmedit.models.editors.ESRGAN(generator, discriminator=None, gan_loss=None, pixel_loss=None, perceptual_loss=None, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)

Bases: mmedit.models.editors.srgan.SRGAN

Enhanced SRGAN model for single image super-resolution.

Ref: ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. It uses RaGAN for GAN updates: The relativistic discriminator: a key element missing from standard GAN.

Parameters
  • generator (dict) – Config for the generator.

  • discriminator (dict) – Config for the discriminator. Default: None.

  • gan_loss (dict) – Config for the gan loss. Note that the loss weight in gan loss is only for the generator.

  • pixel_loss (dict) – Config for the pixel loss. Default: None.

  • perceptual_loss (dict) – Config for the perceptual loss. Default: None.

  • train_cfg (dict) – Config for training. Default: None. You may change the training of gan by setting: disc_steps: how many discriminator updates after one generate update; disc_init_steps: how many discriminator updates at the start of the training. These two keys are useful when training with WGAN.

  • test_cfg (dict) – Config for testing. Default: None.

  • init_cfg (dict, optional) – The weight initialized config for BaseModule. Default: None.

g_step(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor)

G step of GAN: Calculate losses of generator.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

Returns

Dict of losses.

Return type

dict

d_step_real(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor)

D step of real data.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

Returns

Dict of losses.

Return type

dict

d_step_fake(batch_outputs: torch.Tensor, batch_gt_data)

D step of fake data.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

Returns

Dict of losses.

Return type

dict
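The RaGAN updates mentioned above compare each logit against the mean logit of the opposite class rather than using absolute real/fake logits. A sketch of the relativistic average logits, assuming the standard RaGAN formulation (relativistic_logits is an illustrative helper, not mmedit API):

```python
def relativistic_logits(real_logits, fake_logits):
    """Relativistic average logits used by RaGAN (sketch): each real
    logit is measured against the mean fake logit, and vice versa."""
    mean_real = sum(real_logits) / len(real_logits)
    mean_fake = sum(fake_logits) / len(fake_logits)
    d_real = [r - mean_fake for r in real_logits]  # D_ra(real, fake)
    d_fake = [f - mean_real for f in fake_logits]  # D_ra(fake, real)
    return d_real, d_fake

d_real, d_fake = relativistic_logits([2.0, 4.0], [0.0, 1.0])
print(d_real, d_fake)  # [1.5, 3.5] [-3.0, -2.0]
```

In training, the BCE (or hinge) loss is then applied to these relative logits instead of the raw discriminator outputs.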

class mmedit.models.editors.RRDBNet(in_channels, out_channels, mid_channels=64, num_blocks=23, growth_channels=32, upscale_factor=4, init_cfg=None)

Bases: mmengine.model.BaseModule

Networks consisting of Residual in Residual Dense Block, which is used in ESRGAN and Real-ESRGAN.

ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. # noqa: E501 Currently, it supports [x1/x2/x4] upsampling scale factors.

Parameters
  • in_channels (int) – Channel number of inputs.

  • out_channels (int) – Channel number of outputs.

  • mid_channels (int) – Channel number of intermediate features. Default: 64

  • num_blocks (int) – Block number in the trunk network. Defaults: 23

  • growth_channels (int) – Channels for each growth. Default: 32.

  • upscale_factor (int) – Upsampling factor. Support x1, x2 and x4. Default: 4.

_supported_upscale_factors = [1, 2, 4]

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor

init_weights()

Init weights for models.

class mmedit.models.editors.FBADecoder(pool_scales, in_channels, channels, conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), align_corners=False)

Bases: torch.nn.Module

Decoder for FBA matting.

Parameters
  • pool_scales (tuple[int]) – Pooling scales used in Pooling Pyramid Module.

  • in_channels (int) – Input channels.

  • channels (int) – Channels after modules, before conv_seg.

  • conv_cfg (dict|None) – Config of conv layers.

  • norm_cfg (dict|None) – Config of norm layers.

  • act_cfg (dict) – Config of activation layers.

  • align_corners (bool) – align_corners argument of F.interpolate.

init_weights(pretrained=None)

Init weights for the model.

Parameters

pretrained (str, optional) – Path for pretrained weights. If given None, pretrained weights will not be loaded. Defaults to None.

forward(inputs)

Forward function.

Parameters

inputs (dict) – Output dict of FbaEncoder.

Returns

Predicted alpha, fg and bg of the current batch.

Return type

tuple(Tensor)

class mmedit.models.editors.FBAResnetDilated(depth, in_channels, stem_channels, base_channels, num_stages=4, strides=(1, 2, 2, 2), dilations=(1, 1, 2, 4), deep_stem=False, avg_down=False, frozen_stages=- 1, act_cfg=dict(type='ReLU'), conv_cfg=None, norm_cfg=dict(type='BN'), with_cp=False, multi_grid=None, contract_dilation=False, zero_init_residual=True)

Bases: mmedit.models.base_archs.ResNet

ResNet-based encoder for FBA image matting.

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape (N, C, H, W).

Returns

Output tensor.

Return type

Tensor

class mmedit.models.editors.FLAVR(generator, pixel_loss, train_cfg=None, test_cfg=None, required_frames=2, step_frames=1, init_cfg=None, data_preprocessor=None)

Bases: mmedit.models.base_models.BasicInterpolator

FLAVR model for video interpolation.

Paper:

FLAVR: Flow-Agnostic Video Representations for Fast Frame Interpolation

Ref repo: https://github.com/tarun005/FLAVR

Parameters
  • generator (dict) – Config for the generator structure.

  • pixel_loss (dict) – Config for pixel-wise loss.

  • train_cfg (dict) – Config for training. Default: None.

  • test_cfg (dict) – Config for testing. Default: None.

  • required_frames (int) – Required frames in each process. Default: 2

  • step_frames (int) – Step size of video frame interpolation. Default: 1

  • init_cfg (dict, optional) – The weight initialized config for BaseModule.

  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.

init_cfg

Initialization config dict.

Type

dict, optional

data_preprocessor

Used for pre-processing data sampled by dataloader to the format accepted by forward().

Type

BaseDataPreprocessor

static merge_frames(input_tensors, output_tensors)

Merge input frames and output frames.

An interpolated frame is inserted between each pair of input frames.

Merged from

[[in1, in2, in3, in4], [in2, in3, in4, in5], …] and [[out1], [out2], [out3], …]

to

[in1, in2, out1, in3, out2, …, in(-3), out(-1), in(-2), in(-1)]

Parameters
  • input_tensors (Tensor) – The input frames with shape [n, 4, c, h, w]

  • output_tensors (Tensor) – The output frames with shape [n, 1, c, h, w].

Returns

The final frames.

Return type

list[np.array]
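The merge pattern above can be sketched in pure Python: each output frame sits between the two middle frames of its 4-frame input window (this merge_frames is an illustrative list-based re-implementation, not the actual tensor-based static method):

```python
def merge_frames(inputs, outputs):
    """Interleave interpolated frames with input frames: outputs[k]
    is inserted between inputs[k + 1] and inputs[k + 2]."""
    merged = [inputs[0]]
    for k, out in enumerate(outputs):
        merged += [inputs[k + 1], out]
    # Append the remaining trailing input frames.
    merged += inputs[len(outputs) + 1:]
    return merged

print(merge_frames(['in1', 'in2', 'in3', 'in4', 'in5'], ['out1', 'out2']))
# ['in1', 'in2', 'out1', 'in3', 'out2', 'in4', 'in5']
```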

class mmedit.models.editors.FLAVRNet(num_input_frames, num_output_frames, mid_channels_list=[512, 256, 128, 64], encoder_layers_list=[2, 2, 2, 2], bias=False, norm_cfg=None, join_type='concat', up_mode='transpose', init_cfg=None)

Bases: mmengine.model.BaseModule

PyTorch implementation of FLAVR for video frame interpolation.

Paper:

FLAVR: Flow-Agnostic Video Representations for Fast Frame Interpolation

Ref repo: https://github.com/tarun005/FLAVR

Parameters
  • num_input_frames (int) – Number of input frames.

  • num_output_frames (int) – Number of output frames.

  • mid_channels_list (list[int]) – List of number of mid channels. Default: [512, 256, 128, 64]

  • encoder_layers_list (list[int]) – List of number of layers in encoder. Default: [2, 2, 2, 2]

  • bias (bool) – If True, adds a learnable bias to the conv layers. Default: False

  • norm_cfg (dict | None) – Config dict for normalization layer. Default: None

  • join_type (str) – Join type of tensors from decoder and encoder. Candidates are concat and add. Default: concat

  • up_mode (str) – Up-sampling mode of UpConv3d, candidates are transpose and trilinear. Default: transpose

  • init_cfg (dict, optional) – Initialization config dict. Default: None.

forward(images: torch.Tensor)

Forward function.

Parameters

images (Tensor) – Input frames tensor with shape (N, T, C, H, W).

Returns

Output tensor.

Return type

out (Tensor)

class mmedit.models.editors.GCA(data_preprocessor, backbone, loss_alpha=None, init_cfg: Optional[dict] = None, train_cfg=None, test_cfg=None)

Bases: mmedit.models.base_models.BaseMattor

Guided Contextual Attention image matting model.

https://arxiv.org/abs/2001.04069

Parameters
  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.

  • backbone (dict) – Config of backbone.

  • loss_alpha (dict) – Config of the alpha prediction loss. Default: None.

  • init_cfg (dict, optional) – Initialization config dict. Default: None.

  • train_cfg (dict) – Config of training. In train_cfg, train_backbone should be specified. If the model has a refiner, train_refiner should be specified.

  • test_cfg (dict) – Config of testing. In test_cfg, if the model has a refiner, train_refiner should be specified.

_forward(inputs)

Forward function.

Parameters

inputs (torch.Tensor) – Input tensor.

Returns

Output tensor.

Return type

Tensor

_forward_test(inputs)

Forward function for testing GCA model.

Parameters

inputs (torch.Tensor) – batch input tensor.

Returns

Output tensor of model.

Return type

Tensor

_forward_train(inputs, data_samples)

Forward function for training GCA model.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement]) – data samples collated by data_preprocessor.

Returns

Contains the loss items and batch information.

Return type

dict

class mmedit.models.editors.GGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None)

Bases: mmedit.models.base_models.BaseGAN

Implementation of Geometric GAN (GGAN).

https://arxiv.org/abs/1705.02894

disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor) Tuple

Get disc loss. GGAN uses hinge loss to train the discriminator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]

gen_loss(disc_pred_fake)

Get gen loss. GGAN uses hinge loss to train the generator.

Parameters

disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]
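The hinge losses used by GGAN can be sketched as follows (illustrative pure-Python helpers assuming the standard hinge-GAN formulation; the real implementation operates on tensors):

```python
def hinge_disc_loss(pred_real, pred_fake):
    """Hinge loss for the discriminator: penalize real logits
    below +1 and fake logits above -1."""
    loss_real = sum(max(0.0, 1.0 - p) for p in pred_real) / len(pred_real)
    loss_fake = sum(max(0.0, 1.0 + p) for p in pred_fake) / len(pred_fake)
    return loss_real + loss_fake

def hinge_gen_loss(pred_fake):
    """Hinge loss for the generator: push fake logits up."""
    return -sum(pred_fake) / len(pred_fake)

print(hinge_disc_loss([2.0, 0.5], [-2.0, 0.0]))  # 0.75
print(hinge_gen_loss([-2.0, 0.0]))  # 1.0
```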

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train discriminator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train generator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Not used in the generator’s training.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

class mmedit.models.editors.GLEANStyleGANv2(in_size, out_size, img_channels=3, rrdb_channels=64, num_rrdbs=23, style_channels=512, num_mlps=8, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], lr_mlp=0.01, default_style_mode='mix', eval_style_mode='single', mix_prob=0.9, init_cfg=None, fp16_enabled=False, bgr2rgb=False)

Bases: mmengine.model.BaseModule

GLEAN (using StyleGANv2) architecture for super-resolution.

Paper:

GLEAN: Generative Latent Bank for Large-Factor Image Super-Resolution, CVPR, 2021

This method makes use of StyleGAN2 and hence the arguments mostly follow those in ‘StyleGANv2Generator’.

In StyleGAN2, we use a static architecture composed of a style mapping module and a number of convolutional style blocks. More details can be found in: Analyzing and Improving the Image Quality of StyleGAN, CVPR 2020.

You can load pretrained model through passing information into pretrained argument. We have already offered official weights as follows:

If you want to load the ema model, you can just use following codes:

# ckpt_http is one of the valid path from http source
generator = StyleGANv2Generator(1024, 512,
                                pretrained=dict(
                                    ckpt_path=ckpt_http,
                                    prefix='generator_ema'))

Of course, you can also download the checkpoint in advance and set ckpt_path to a local path. If you just want to load the original generator (not the ema model), please set the prefix to ‘generator’.

Note that our implementation allows generating BGR images, while the original StyleGAN2 outputs RGB images by default. Thus, we provide the bgr2rgb argument to convert the image space.

Parameters
  • in_size (int) – The size of the input image.

  • out_size (int) – The output size of the StyleGAN2 generator.

  • img_channels (int) – Number of channels of the input images. 3 for RGB image and 1 for grayscale image. Default: 3.

  • rrdb_channels (int) – Number of channels of the RRDB features. Default: 64.

  • num_rrdbs (int) – Number of RRDB blocks in the encoder. Default: 23.

  • style_channels (int) – The number of channels for style code. Default: 512.

  • num_mlps (int, optional) – The number of MLP layers. Defaults to 8.

  • channel_multiplier (int, optional) – The multiplier factor for the channel number. Defaults to 2.

  • blur_kernel (list, optional) – The blurry kernel. Defaults to [1, 3, 3, 1].

  • lr_mlp (float, optional) – The learning rate for the style mapping layer. Defaults to 0.01.

  • default_style_mode (str, optional) – The default mode of style mixing. In training, we adopt the mixing style mode by default. However, in the evaluation, we use the ‘single’ style mode. [‘mix’, ‘single’] are currently supported. Defaults to ‘mix’.

  • eval_style_mode (str, optional) – The evaluation mode of style mixing. Defaults to ‘single’.

  • mix_prob (float, optional) – Mixing probability. The value should be in range of [0, 1]. Defaults to 0.9.

  • init_cfg (dict, optional) – Initialization config dict. Default: None.

  • fp16_enabled (bool, optional) – Whether to use fp16 training in this module. Defaults to False.

  • bgr2rgb (bool, optional) – Whether to flip the image channel dimension. Defaults to False.

forward(lq)

Forward function.

Parameters

lq (Tensor) – Input LR image with shape (n, c, h, w).

Returns

Output HR image.

Return type

Tensor

class mmedit.models.editors.GLDecoder(in_channels=256, norm_cfg=None, act_cfg=dict(type='ReLU'), out_act='clip')

Bases: torch.nn.Module

Decoder used in Global&Local model.

This implementation follows: Globally and Locally Consistent Image Completion

Parameters
  • in_channels (int) – Channel number of input feature.

  • norm_cfg (dict) – Config dict to build norm layer.

  • act_cfg (dict) – Config dict for activation layer, “relu” by default.

  • out_act (str) – Output activation type, “clip” by default. Note that in our implementation, we clip the output to the range [-1, 1].
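The ‘clip’ output activation simply bounds every value to [-1, 1]; a minimal plain-Python sketch (hypothetical helper name):

```python
def clip_activation(values, low=-1.0, high=1.0):
    """Clamp every element into [low, high], mirroring out_act='clip'."""
    return [max(low, min(high, v)) for v in values]

clip_activation([-2.0, -0.5, 0.0, 0.7, 3.0])  # -> [-1.0, -0.5, 0.0, 0.7, 1.0]
```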

forward(x)

Forward Function.

Parameters

x (torch.Tensor) – Input tensor with shape of (n, c, h, w).

Returns

Output tensor with shape of (n, c, h’, w’).

Return type

torch.Tensor

class mmedit.models.editors.GLDilationNeck(in_channels=256, conv_type='conv', norm_cfg=None, act_cfg=dict(type='ReLU'), **kwargs)

Bases: torch.nn.Module

Dilation Backbone used in Global&Local model.

This implementation follows: Globally and Locally Consistent Image Completion

Parameters
  • in_channels (int) – Channel number of input feature.

  • conv_type (str) – The type of conv module. In DeepFillv1 model, the conv_type should be ‘conv’. In DeepFillv2 model, the conv_type should be ‘gated_conv’.

  • norm_cfg (dict) – Config dict to build norm layer.

  • act_cfg (dict) – Config dict for activation layer, “relu” by default.

  • kwargs (keyword arguments) – Other keyword arguments passed to the conv modules.

_conv_type
forward(x)

Forward Function.

Parameters

x (torch.Tensor) – Input tensor with shape of (n, c, h, w).

Returns

Output tensor with shape of (n, c, h’, w’).

Return type

torch.Tensor

class mmedit.models.editors.GLEncoder(norm_cfg=None, act_cfg=dict(type='ReLU'))

Bases: torch.nn.Module

Encoder used in Global&Local model.

This implementation follows: Globally and Locally Consistent Image Completion

Parameters
  • norm_cfg (dict) – Config dict to build norm layer.

  • act_cfg (dict) – Config dict for activation layer, “relu” by default.

forward(x)

Forward Function.

Parameters

x (torch.Tensor) – Input tensor with shape of (n, c, h, w).

Returns

Output tensor with shape of (n, c, h’, w’).

Return type

torch.Tensor

class mmedit.models.editors.GLEncoderDecoder(encoder=dict(type='GLEncoder'), decoder=dict(type='GLDecoder'), dilation_neck=dict(type='GLDilationNeck'))

Bases: torch.nn.Module

Encoder-Decoder used in Global&Local model.

This implementation follows: Globally and Locally Consistent Image Completion

The architecture of the encoder-decoder is: (conv2d x 6) –> (dilated conv2d x 4) –> (conv2d or deconv2d x 7)
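The four dilated convolutions in the middle of this architecture enlarge the receptive field rapidly. Assuming 3x3 kernels with dilation rates 2, 4, 8 and 16 (the rates used in the original Global&Local paper; treat them as an assumption here), the growth can be computed with standard receptive-field arithmetic:

```python
def receptive_field(kernel_sizes, dilations, strides=None):
    """Compute the receptive field of a stack of conv layers.

    Each layer adds (k - 1) * dilation * jump, where `jump` is the
    product of the strides of all earlier layers.
    """
    if strides is None:
        strides = [1] * len(kernel_sizes)
    rf, jump = 1, 1
    for k, d, s in zip(kernel_sizes, dilations, strides):
        rf += (k - 1) * d * jump
        jump *= s
    return rf

# Four 3x3 dilated convs with rates 2, 4, 8, 16 (stride 1):
receptive_field([3] * 4, [2, 4, 8, 16])  # returns 61
```

So the dilated neck alone sees a 61x61 window at its resolution, which is why it can propagate context across large masked regions.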

Parameters
  • encoder (dict) – Config dict to encoder.

  • decoder (dict) – Config dict to build decoder.

  • dilation_neck (dict) – Config dict to build dilation neck.

forward(x)

Forward Function.

Parameters

x (torch.Tensor) – Input tensor with shape of (n, c, h, w).

Returns

Output tensor with shape of (n, c, h’, w’).

Return type

torch.Tensor

init_weights(pretrained=None, strict=True)

Init weights for models.

Parameters
  • pretrained (str, optional) – Path for pretrained weights. If given None, pretrained weights will not be loaded. Defaults: None.

  • strict (bool, optional) – Whether to strictly load the pretrained model. Defaults to True.

class mmedit.models.editors.AblatedDiffusionModel(data_preprocessor, unet, diffusion_scheduler, use_fp16=False, classifier=None, classifier_scale=1.0, pretrained_cfgs=None)

Bases: mmengine.model.BaseModel

Guided diffusion Model.

Parameters
  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.

  • unet (ModelType) – Config of denoising Unet.

  • diffusion_scheduler (ModelType) – Config of diffusion_scheduler scheduler.

  • use_fp16 (bool) – Whether to use fp16 for unet model. Defaults to False.

  • classifier (ModelType) – Config of classifier. Defaults to None.

  • pretrained_cfgs (dict) – Path config for pretrained weights. Usually this is a dict that maps module names to the corresponding checkpoint paths. Defaults to None.

property device

Get current device of the model.

Returns

The current device of the model.

Return type

torch.device

load_pretrained_models(pretrained_cfgs)

Load pretrained weights for the submodules of this model.

Parameters

pretrained_cfgs (dict) – Dict mapping module names to the corresponding checkpoint paths.

infer(init_image=None, batch_size=1, num_inference_steps=1000, labels=None, classifier_scale=0.0, show_progress=False)

Sample images with the diffusion model.

Parameters
  • init_image (torch.Tensor, optional) – Image to initialize the reverse diffusion process. If None, sampling starts from random noise. Defaults to None.

  • batch_size (int, optional) – Number of images to generate. Defaults to 1.

  • num_inference_steps (int, optional) – Number of denoising steps. Defaults to 1000.

  • labels (torch.Tensor, optional) – Class labels used for classifier guidance. Defaults to None.

  • classifier_scale (float, optional) – Scale of the classifier guidance. Defaults to 0.0.

  • show_progress (bool, optional) – Whether to show a progress bar during sampling. Defaults to False.

Returns

A dict containing the generated samples.

Return type

dict

forward(inputs: mmedit.utils.typing.ForwardInputs, data_samples: Optional[list] = None, mode: Optional[str] = None) List[mmedit.structures.EditDataSample]

Forward function.

Parameters
  • inputs (ForwardInputs) – Inputs for the forward pass.

  • data_samples (Optional[list], optional) – Data samples collated by data_preprocessor. Defaults to None.

  • mode (Optional[str], optional) – The forward mode. Defaults to None.

Returns

Generated samples wrapped in EditDataSample.

Return type

List[EditDataSample]

val_step(data: dict) mmedit.utils.typing.SampleList

Gets the generated image of given data.

Calls self.data_preprocessor(data) and self(inputs, data_sample, mode=None) in order. Returns the generated results, which will be passed to the evaluator.

Parameters

data (dict) – Data sampled from the metric-specific sampler. More details in Metrics and Evaluator.

Returns

Generated image or image dict.

Return type

SampleList

test_step(data: dict) mmedit.utils.typing.SampleList

Gets the generated image of given data. Same as val_step().

Parameters

data (dict) – Data sampled from the metric-specific sampler. More details in Metrics and Evaluator.

Returns

Generated image or image dict.

Return type

List[EditDataSample]

train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict)

Train the model one step.

Parameters
  • data (dict) – Data sampled from the dataloader.

  • optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance containing the OptimWrapper of each submodule.

Returns

A dict of loss tensors for logging.

Return type

Dict[str, torch.Tensor]

class mmedit.models.editors.IconVSRNet(mid_channels=64, num_blocks=30, keyframe_stride=5, padding=2, spynet_pretrained=None, edvr_pretrained=None)

Bases: mmengine.model.BaseModule

IconVSR network structure for video super-resolution.

Supports only x4 upsampling.

Paper:

BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond, CVPR, 2021

Parameters
  • mid_channels (int) – Channel number of the intermediate features. Default: 64.

  • num_blocks (int) – Number of residual blocks in each propagation branch. Default: 30.

  • keyframe_stride (int) – Number determining the keyframes. If stride=5, then the (0, 5, 10, 15, …)-th frame will be the keyframes. Default: 5.

  • padding (int) – Number of frames to be padded at two ends of the sequence. 2 for REDS and 3 for Vimeo-90K. Default: 2.

  • spynet_pretrained (str) – Pre-trained model path of SPyNet. Default: None.

  • edvr_pretrained (str) – Pre-trained model path of EDVR (for refill). Default: None.
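The keyframe selection and end-of-sequence padding described above can be sketched in plain Python (both helper names are hypothetical, and the mirror-style padding is an assumption about how the `padding` frames are filled):

```python
def keyframe_indices(num_frames, stride):
    """Every `stride`-th frame is treated as a keyframe."""
    return list(range(0, num_frames, stride))

def pad_sequence(frames, padding):
    """Reflect-pad `padding` frames at both ends of the sequence,
    mimicking the temporal padding used for REDS (2) / Vimeo-90K (3)."""
    head = frames[1:padding + 1][::-1]
    tail = frames[-padding - 1:-1][::-1]
    return head + frames + tail

keyframe_indices(15, 5)        # -> [0, 5, 10]
pad_sequence([0, 1, 2, 3], 2)  # -> [2, 1, 0, 1, 2, 3, 2, 1]
```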

spatial_padding(lrs)

Apply padding spatially.

Since the PCD module in EDVR requires that the resolution is a multiple of 4, we apply padding to the input LR images if their resolution is not divisible by 4.

Parameters

lrs (Tensor) – Input LR sequence with shape (n, t, c, h, w).

Returns

Padded LR sequence with shape (n, t, c, h_pad, w_pad).

Return type

Tensor
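The amount of spatial padding follows directly from the divisible-by-4 requirement; a minimal sketch of the size arithmetic (hypothetical helper name):

```python
def spatial_pad_amounts(h, w, multiple=4):
    """Return (pad_h, pad_w) so that (h + pad_h) and (w + pad_w)
    become multiples of `multiple`, as required by the PCD module."""
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    return pad_h, pad_w

spatial_pad_amounts(180, 320)  # -> (0, 0): already divisible by 4
spatial_pad_amounts(135, 241)  # -> (1, 3)
```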

check_if_mirror_extended(lrs)

Check whether the input is a mirror-extended sequence.

If mirror-extended, the i-th (i=0, …, t-1) frame is equal to the (t-1-i)-th frame.

Parameters

lrs (tensor) – Input LR images with shape (n, t, c, h, w)
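The mirror-extension test compares the first half of the sequence with the time-reversed second half; in plain Python (a hypothetical sketch operating on lists of frames):

```python
def is_mirror_extended(frames):
    """A sequence is mirror-extended if frame i equals frame t-1-i,
    i.e. the second half is the first half reversed."""
    t = len(frames)
    if t % 2 != 0:
        return False
    return frames[:t // 2] == frames[t // 2:][::-1]

is_mirror_extended([0, 1, 2, 2, 1, 0])  # -> True
is_mirror_extended([0, 1, 2, 3, 4, 5])  # -> False
```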

compute_refill_features(lrs, keyframe_idx)

Compute keyframe features for information-refill.

Since EDVR-M is used, padding is performed before feature computation.

Parameters
  • lrs (Tensor) – Input LR images with shape (n, t, c, h, w).

  • keyframe_idx (list[int]) – The indices specifying the keyframes.

Returns

The keyframe features. Each key corresponds to one of the indices in keyframe_idx.

Return type

dict(Tensor)

compute_flow(lrs)

Compute optical flow using SPyNet for feature warping.

Note that if the input is a mirror-extended sequence, ‘flows_forward’ is not needed, since it is equal to ‘flows_backward.flip(1)’.

Parameters

lrs (tensor) – Input LR images with shape (n, t, c, h, w)

Returns

Optical flow. ‘flows_forward’ corresponds to the flows used for forward-time propagation (current to previous). ‘flows_backward’ corresponds to the flows used for backward-time propagation (current to next).

Return type

tuple(Tensor)

forward(lrs)

Forward function for IconVSR.

Parameters

lrs (Tensor) – Input LR tensor with shape (n, t, c, h, w).

Returns

Output HR tensor with shape (n, t, c, 4h, 4w).

Return type

Tensor

class mmedit.models.editors.DepthwiseIndexBlock(in_channels, norm_cfg=dict(type='BN'), use_context=False, use_nonlinear=False, mode='o2o')

Bases: torch.nn.Module

Depthwise index block.

From https://arxiv.org/abs/1908.00672.

Parameters
  • in_channels (int) – Input channels of the holistic index block.

  • norm_cfg (dict) – Config dict for normalization layer. Default: dict(type=’BN’).

  • use_context (bool, optional) – Whether to use a larger kernel size in the index block. Refer to the paper for more information. Defaults to False.

  • use_nonlinear (bool) – Whether to add a non-linear conv layer in the index blocks. Default: False.

  • mode (str) – Mode of the index block. Should be ‘o2o’ or ‘m2o’. In ‘o2o’ mode, the group of the conv layers is 1; in ‘m2o’ mode, the group of the conv layers is in_channels.

forward(x)

Forward function.

Parameters

x (Tensor) – Input feature map with shape (N, C, H, W).

Returns

Encoder index feature and decoder index feature.

Return type

tuple(Tensor)

class mmedit.models.editors.HolisticIndexBlock(in_channels, norm_cfg=dict(type='BN'), use_context=False, use_nonlinear=False)

Bases: torch.nn.Module

Holistic Index Block.

From https://arxiv.org/abs/1908.00672.

Parameters
  • in_channels (int) – Input channels of the holistic index block.

  • norm_cfg (dict) – Config dict for normalization layer. Default: dict(type=’BN’).

  • use_context (bool, optional) – Whether to use a larger kernel size in the index block. Refer to the paper for more information. Defaults to False.

  • use_nonlinear (bool) – Whether to add a non-linear conv layer in the index block. Default: False.

forward(x)

Forward function.

Parameters

x (Tensor) – Input feature map with shape (N, C, H, W).

Returns

Encoder index feature and decoder index feature.

Return type

tuple(Tensor)

class mmedit.models.editors.IndexedUpsample(in_channels, out_channels, kernel_size=5, norm_cfg=dict(type='BN'), conv_module=ConvModule, init_cfg: Optional[dict] = None)

Bases: mmengine.model.BaseModule

Indexed upsample module.

Parameters
  • in_channels (int) – Input channels.

  • out_channels (int) – Output channels.

  • kernel_size (int, optional) – Kernel size of the convolution layer. Defaults to 5.

  • norm_cfg (dict, optional) – Config dict for normalization layer. Defaults to dict(type=’BN’).

  • conv_module (ConvModule | DepthwiseSeparableConvModule, optional) – Conv module. Defaults to ConvModule.

  • init_cfg (dict, optional) – Initialization config dict. Default: None.

init_weights()

Init weights for the module.

forward(x, shortcut, dec_idx_feat=None)

Forward function.

Parameters
  • x (Tensor) – Input feature map with shape (N, C, H, W).

  • shortcut (Tensor) – The shortcut connection with shape (N, C, H’, W’).

  • dec_idx_feat (Tensor, optional) – The decode index feature map with shape (N, C, H’, W’). Defaults to None.

Returns

Output tensor with shape (N, C, H’, W’).

Return type

Tensor

class mmedit.models.editors.IndexNet(data_preprocessor, backbone, loss_alpha=None, loss_comp=None, init_cfg=None, train_cfg=None, test_cfg=None)

Bases: mmedit.models.base_models.BaseMattor

IndexNet matting model.

This implementation follows: Indices Matter: Learning to Index for Deep Image Matting

Parameters
  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.

  • backbone (dict) – Config of backbone.

  • train_cfg (dict) – Config of training. In ‘train_cfg’, ‘train_backbone’ should be specified.

  • test_cfg (dict) – Config of testing.

  • init_cfg (dict, optional) – The weight initialized config for BaseModule.

  • loss_alpha (dict) – Config of the alpha prediction loss. Default: None.

  • loss_comp (dict) – Config of the composition loss. Default: None.

_forward(inputs)

Forward function.

Parameters

inputs (torch.Tensor) – Input tensor.

Returns

Output tensor.

Return type

Tensor

_forward_test(inputs)

Forward function for testing IndexNet model.

Parameters

inputs (torch.Tensor) – batch input tensor.

Returns

Output tensor of model.

Return type

Tensor

_forward_train(inputs, data_samples)

Forward function for training IndexNet model.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement]) – data samples collated by data_preprocessor.

Returns

Contains the loss items and batch information.

Return type

dict

class mmedit.models.editors.IndexNetDecoder(in_channels, kernel_size=5, norm_cfg=dict(type='BN'), separable_conv=False, init_cfg: Optional[dict] = None)

Bases: mmengine.model.BaseModule

Decoder for IndexNet.

Please refer to https://arxiv.org/abs/1908.00672.

Parameters
  • in_channels (int) – Input channels of the decoder.

  • kernel_size (int, optional) – Kernel size of the convolution layer. Defaults to 5.

  • norm_cfg (None | dict, optional) – Config dict for normalization layer. Defaults to dict(type=’BN’).

  • separable_conv (bool) – Whether to use separable conv. Default: False.

  • init_cfg (dict, optional) – Initialization config dict. Default: None.

init_weights()

Init weights for the module.

forward(inputs)

Forward function.

Parameters

inputs (dict) – Output dict of IndexNetEncoder.

Returns

Predicted alpha matte of the current batch.

Return type

Tensor

class mmedit.models.editors.IndexNetEncoder(in_channels, out_stride=32, width_mult=1, index_mode='m2o', aspp=True, norm_cfg=dict(type='BN'), freeze_bn=False, use_nonlinear=True, use_context=True, init_cfg: Optional[dict] = None)

Bases: mmengine.model.BaseModule

Encoder for IndexNet.

Please refer to https://arxiv.org/abs/1908.00672.

Parameters
  • in_channels (int, optional) – Input channels of the encoder.

  • out_stride (int, optional) – Output stride of the encoder. For example, if out_stride is 32, the input feature map or image will be downsampled to 1/32 of its original size. Defaults to 32.

  • width_mult (int, optional) – Width multiplication factor of channel dimension in MobileNetV2. Defaults to 1.

  • index_mode (str, optional) – Index mode of the index network. It must be one of {‘holistic’, ‘o2o’, ‘m2o’}. If it is set to ‘holistic’, the holistic index network will be used as the index network. If it is set to ‘o2o’ (or ‘m2o’), the O2O (or M2O) depthwise index network will be used as the index network. Defaults to ‘m2o’.

  • aspp (bool, optional) – Whether to use the ASPP module to augment the output feature. Defaults to True.

  • norm_cfg (None | dict, optional) – Config dict for normalization layer. Defaults to dict(type=’BN’).

  • freeze_bn (bool, optional) – Whether to freeze the batch norm layers. Defaults to False.

  • use_nonlinear (bool, optional) – Whether to use nonlinearity in the index network. Refer to the paper for more information. Defaults to True.

  • use_context (bool, optional) – Whether to use a larger kernel size in the index network. Refer to the paper for more information. Defaults to True.

  • init_cfg (dict, optional) – Initialization config dict. Default: None.

Raises
  • ValueError – out_stride must be 16 or 32.

  • NameError – Supported index_mode are {‘holistic’, ‘o2o’, ‘m2o’}.
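The two checks in the Raises section amount to simple argument validation; a standalone sketch (hypothetical function, not the mmedit constructor):

```python
def validate_encoder_args(out_stride, index_mode):
    """Reproduce the argument checks documented above."""
    if out_stride not in (16, 32):
        raise ValueError(f'out_stride must be 16 or 32, but got {out_stride}')
    if index_mode not in ('holistic', 'o2o', 'm2o'):
        raise NameError(f"Unknown index_mode {index_mode}; supported modes "
                        "are 'holistic', 'o2o' and 'm2o'")
    return True
```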

_make_layer(layer_setting, norm_cfg)
train(mode=True)

Set BatchNorm modules in the model to evaluation mode.

init_weights()

Init weights for the model.

Initialization is based on self._init_cfg


forward(x)

Forward function.

Parameters

x (Tensor) – Input feature map with shape (N, C, H, W).

Returns

Output tensor, shortcut feature and decoder index feature.

Return type

dict

class mmedit.models.editors.InstColorization(data_preprocessor: Union[dict, mmengine.config.Config], image_model, instance_model, fusion_model, color_data_opt, which_direction='AtoB', loss=None, init_cfg=None, train_cfg=None, test_cfg=None)

Bases: mmengine.model.BaseModel

InstColorization colorization method.

This method is implemented according to the paper:

Instance-aware Image Colorization, CVPR 2020

Adapted from ‘https://github.com/ericsujw/InstColorization.git’ ‘InstColorization/models/train_model’ Copyright (c) 2020, Su, under MIT License.

Parameters
  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.

  • image_model (dict) – Config for single image model

  • instance_model (dict) – Config for instance model

  • fusion_model (dict) – Config for fusion model

  • color_data_opt (dict) – Option for colorspace conversion

  • which_direction (str) – AtoB or BtoA

  • loss (dict) – Config for loss.

  • init_cfg (str) – Initialization config dict. Default: None.

  • train_cfg (dict) – Config for training. Default: None.

  • test_cfg (dict) – Config for testing. Default: None.

forward(inputs: torch.Tensor, data_samples: Optional[List[mmedit.structures.EditDataSample]] = None, mode: str = 'tensor', **kwargs)

Returns losses or predictions of training, validation, testing, and simple inference process.

forward is an abstract method of BaseModel, and its subclasses must implement it.

Accepts inputs and data_samples processed by data_preprocessor, and returns results according to mode arguments.

During the non-distributed training, validation, and testing process, forward will be called by BaseModel.train_step, BaseModel.val_step and BaseModel.test_step directly.

During distributed data parallel training process, MMSeparateDistributedDataParallel.train_step will first call DistributedDataParallel.forward to enable automatic gradient synchronization, and then call forward to get training loss.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

  • mode (str) –

    mode should be one of loss, predict and tensor. Default: ‘tensor’.

    • loss: Called by train_step and return loss dict used for logging

    • predict: Called by val_step and test_step and return list of BaseDataElement results used for computing metric.

    • tensor: Called by custom use to get Tensor type results.

Returns

  • If mode == loss, return a dict of loss tensor used for backward and logging.

  • If mode == predict, return a list of BaseDataElement for computing metric and getting inference result.

  • If mode == tensor, return a tensor, a tuple of tensors, or a dict of tensors for custom use.

Return type

ForwardResults

convert_to_datasample(inputs, data_samples)
abstract forward_train(inputs, data_samples=None, **kwargs)

Forward function for training.

abstract train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) Dict[str, torch.Tensor]

Train step function.

Parameters
  • data (List[dict]) – Batch of data as input.

  • optim_wrapper (dict[torch.optim.Optimizer]) – Dict with optimizers for generator and discriminator (if have).

Returns

Dict with the loss, information for the logger, the number of samples, and results for visualization.

Return type

dict

forward_inference(inputs, data_samples=None, **kwargs)

Forward inference. Returns predictions of validation, testing.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

predictions.

Return type

List[EditDataSample]

forward_tensor(inputs, data_samples)

Forward function in tensor mode.

Parameters
  • inputs (torch.Tensor) – Input tensor.

  • data_sample (dict) – Dict contains data sample.

Returns

Dict contains output results.

Return type

dict

class mmedit.models.editors.LIIF(generator, pixel_loss, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)

Bases: mmedit.models.base_models.BaseEditModel

LIIF model for single image super-resolution.

Paper: Learning Continuous Image Representation with

Local Implicit Image Function

Parameters
  • generator (dict) – Config for the generator.

  • pixel_loss (dict) – Config for the pixel loss.

  • pretrained (str) – Path for pretrained model. Default: None.

  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.

forward_tensor(inputs, data_samples=None, **kwargs)

Forward tensor. Returns result of simple forward.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

result of simple forward.

Return type

Tensor

forward_inference(inputs, data_samples=None, **kwargs)

Forward inference. Returns predictions of validation, testing, and simple inference.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

predictions.

Return type

List[EditDataSample]

class mmedit.models.editors.MLPRefiner(in_dim, out_dim, hidden_list)

Bases: mmengine.model.BaseModule

Multilayer perceptrons (MLPs), refiner used in LIIF.

Parameters
  • in_dim (int) – Input dimension.

  • out_dim (int) – Output dimension.

  • hidden_list (list[int]) – List of hidden dimensions.

forward(x)

Forward function.

Parameters

x (Tensor) – The input of MLP.

Returns

The output of MLP.

Return type

Tensor

class mmedit.models.editors.LSGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None)

Bases: mmedit.models.base_models.BaseGAN

Implementation of Least Squares Generative Adversarial Networks.

Paper link: https://arxiv.org/pdf/1611.04076.pdf

Detailed architectures can be found in :class:`~mmgen.models.architectures.lsgan.generator_generator.LSGANGenerator` and :class:`~mmgen.models.architectures.lsgan.generator_discriminator.LSGANDiscriminator`.

disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor) Tuple

Get disc loss. LSGAN use the least squares loss to train the discriminator.

\[L_{D}=\left(D\left(X_{\text {data }}\right)-1\right)^{2} +(D(G(z)))^{2}\]
Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]

gen_loss(disc_pred_fake: torch.Tensor) Tuple

Get gen loss. LSGAN use the least squares loss to train the generator.

\[L_{G}=(D(G(z))-1)^{2}\]
Parameters

disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]
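Both least-squares losses can be written out in a few lines of plain Python, operating on lists of discriminator scores rather than tensors (a sketch of the formulas above, not the mmedit implementation):

```python
def lsgan_disc_loss(disc_pred_real, disc_pred_fake):
    """L_D = mean((D(x) - 1)^2) + mean(D(G(z))^2)."""
    real = sum((p - 1) ** 2 for p in disc_pred_real) / len(disc_pred_real)
    fake = sum(p ** 2 for p in disc_pred_fake) / len(disc_pred_fake)
    return real + fake

def lsgan_gen_loss(disc_pred_fake):
    """L_G = mean((D(G(z)) - 1)^2)."""
    return sum((p - 1) ** 2 for p in disc_pred_fake) / len(disc_pred_fake)

# A generator that perfectly fools the discriminator (scores of 1)
# incurs zero loss:
lsgan_gen_loss([1.0, 1.0])  # -> 0.0
```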

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train discriminator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train generator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from the dataloader. Not used in the generator’s training.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

class mmedit.models.editors.MSPIEStyleGAN2(*args, train_settings=dict(), **kwargs)

Bases: mmedit.models.editors.stylegan2.StyleGAN2

MS-PIE StyleGAN2.

In this GAN, we adopt the MS-PIE training schedule so that multi-scale images can be generated with a single generator. Details can be found in: Positional Encoding as Spatial Inductive Bias in GANs, CVPR2021.

Parameters

train_settings (dict) – Config for training settings. Defaults to dict().

train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict) Dict[str, torch.Tensor]

Train GAN model. In the training of GAN models, the generator and the discriminator are updated alternately. In MMGeneration’s design, self.train_step is called with data input. Therefore we always update the discriminator, whose update relies on real data, and then determine whether the generator needs to be updated based on the current number of iterations. More details about whether to update the generator can be found in should_gen_update().

Parameters
  • data (dict) – Data sampled from dataloader.

  • optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance contains OptimWrapper of generator and discriminator.

Returns

A dict of tensor for logging.

Return type

Dict[str, torch.Tensor]

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train generator.

Parameters
  • inputs (TrainInput) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from the dataloader. Not used in the generator’s training.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train discriminator.

Parameters
  • inputs (TrainInput) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

class mmedit.models.editors.PESinGAN(generator: ModelType, discriminator: Optional[ModelType], data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, num_scales: Optional[int] = None, fixed_noise_with_pad: bool = False, first_fixed_noises_ch: int = 1, iters_per_scale: int = 200, noise_weight_init: int = 0.1, lr_scheduler_args: Optional[dict] = None, test_pkl_data: Optional[str] = None, ema_confg: Optional[dict] = None)

Bases: mmedit.models.editors.singan.SinGAN

Positional Encoding in SinGAN.

This modified SinGAN is used to reimplement the experiments in: Positional Encoding as Spatial Inductive Bias in GANs, CVPR2021.

construct_fixed_noises()

Construct the fixed noises list used in SinGAN.

class mmedit.models.editors.NAFBaseline(img_channel=3, mid_channels=16, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 28], dec_blk_nums=[1, 1, 1, 1], dw_expand=1, ffn_expand=2)

Bases: mmengine.model.BaseModule

The original version of the Baseline model in “Simple Baselines for Image Restoration”.

Parameters
  • img_channels (int) – Channel number of inputs.

  • mid_channels (int) – Channel number of intermediate features.

  • middle_blk_num (int) – Number of middle blocks.

  • enc_blk_nums (List of int) – Number of blocks for each encoder.

  • dec_blk_nums (List of int) – Number of blocks for each decoder.

forward(inp)

Forward function.

Parameters

inp – input tensor image with (B, C, H, W) shape

check_image_size(x)

Check the image size and pad the image so that its dimensions can be downsampled without rounding.

Parameters

x – input tensor image with (B, C, H, W) shape.
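Assuming each encoder stage halves the resolution (an assumption for this sketch), the required padding can be computed from the number of downsampling stages (hypothetical helper name):

```python
def pad_for_downsampling(h, w, num_down_stages):
    """Return (pad_h, pad_w) so the image can be halved
    `num_down_stages` times without rounding."""
    m = 2 ** num_down_stages
    return (m - h % m) % m, (m - w % m) % m

# With four encoder stages the sizes must be multiples of 16:
pad_for_downsampling(250, 250, 4)  # -> (6, 6)
```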

class mmedit.models.editors.NAFBaselineLocal(*args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs)

Bases: mmedit.models.editors.nafnet.naf_avgpool2d.Local_Base, NAFBaseline

The original version of the Baseline model in “Simple Baselines for Image Restoration”.

Parameters
  • img_channels (int) – Channel number of inputs.

  • mid_channels (int) – Channel number of intermediate features.

  • middle_blk_num (int) – Number of middle blocks.

  • enc_blk_nums (List of int) – Number of blocks for each encoder.

  • dec_blk_nums (List of int) – Number of blocks for each decoder.

class mmedit.models.editors.NAFNet(img_channel=3, mid_channels=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[])

Bases: mmengine.model.BaseModule

NAFNet.

The original version of NAFNet in “Simple Baselines for Image Restoration”.

Parameters
  • img_channels (int) – Channel number of inputs.

  • mid_channels (int) – Channel number of intermediate features.

  • middle_blk_num (int) – Number of middle blocks.

  • enc_blk_nums (List of int) – Number of blocks for each encoder.

  • dec_blk_nums (List of int) – Number of blocks for each decoder.

forward(inp)

Forward function.

Parameters

inp – input tensor image with (B, C, H, W) shape

check_image_size(x)

Check the image size and pad the image so that its dimensions can be downsampled without rounding.

Parameters

x – input tensor image with (B, C, H, W) shape.

class mmedit.models.editors.NAFNetLocal(*args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs)

Bases: mmedit.models.editors.nafnet.naf_avgpool2d.Local_Base, NAFNet

The original version of NAFNetLocal in “Simple Baselines for Image Restoration”.

NAFNetLocal uses local average pooling modules instead of the global average pooling used in NAFNet.

Parameters
  • img_channels (int) – Channel number of inputs.

  • mid_channels (int) – Channel number of intermediate features.

  • middle_blk_num (int) – Number of middle blocks.

  • enc_blk_nums (List of int) – Number of blocks for each encoder.

  • dec_blk_nums (List of int) – Number of blocks for each decoder.

class mmedit.models.editors.MaskConvModule(*args, **kwargs)

Bases: mmcv.cnn.ConvModule

Mask convolution module.

This is a simple wrapper for mask convolutions such as ‘partial conv’. Convolutions in this module always need a mask as extra input.

Parameters
  • in_channels (int) – Same as nn.Conv2d.

  • out_channels (int) – Same as nn.Conv2d.

  • kernel_size (int or tuple[int]) – Same as nn.Conv2d.

  • stride (int or tuple[int]) – Same as nn.Conv2d.

  • padding (int or tuple[int]) – Same as nn.Conv2d.

  • dilation (int or tuple[int]) – Same as nn.Conv2d.

  • groups (int) – Same as nn.Conv2d.

  • bias (bool or str) – If specified as auto, it will be decided by the norm_cfg. Bias will be set as True if norm_cfg is None, otherwise False.

  • conv_cfg (dict) – Config dict for convolution layer.

  • norm_cfg (dict) – Config dict for normalization layer.

  • act_cfg (dict) – Config dict for activation layer, “relu” by default.

  • inplace (bool) – Whether to use inplace mode for activation.

  • with_spectral_norm (bool) – Whether to use spectral norm in the conv module.

  • padding_mode (str) – If the padding_mode has not been supported by current Conv2d in Pytorch, we will use our own padding layer instead. Currently, we support [‘zeros’, ‘circular’] with official implementation and [‘reflect’] with our own implementation. Default: ‘zeros’.

  • order (tuple[str]) – The order of conv/norm/activation layers. It is a sequence of “conv”, “norm” and “act”. Examples are (“conv”, “norm”, “act”) and (“act”, “conv”, “norm”).

supported_conv_list = ['PConv']
forward(x, mask=None, activate=True, norm=True, return_mask=True)

Forward function for partial conv2d.

Parameters
  • x (torch.Tensor) – Tensor with shape of (n, c, h, w).

  • mask (torch.Tensor) – Tensor with shape of (n, c, h, w) or (n, 1, h, w). If mask is not given, the function will work as standard conv2d. Default: None.

  • activate (bool) – Whether use activation layer.

  • norm (bool) – Whether use norm layer.

  • return_mask (bool) – If True and mask is not None, the updated mask will be returned. Default: True.

Returns

Result Tensor, or a 2-tuple of:

Tensor: Result after partial conv.

Tensor: Updated mask, returned if mask is given and return_mask is True.

Return type

Tensor or tuple

class mmedit.models.editors.PartialConv2d(*args, multi_channel=False, eps=1e-08, **kwargs)

Bases: torch.nn.Conv2d

Implementation for partial convolution.

Image Inpainting for Irregular Holes Using Partial Convolutions [https://arxiv.org/abs/1804.07723]

Parameters
  • multi_channel (bool) – If True, the mask is multi-channel. Otherwise, the mask is single-channel.

  • eps (float) – A small value for numerical stability. For mixed precision training, change the default 1e-8 to 1e-6.

forward(input, mask=None, return_mask=True)

Forward function for partial conv2d.

Parameters
  • input (torch.Tensor) – Tensor with shape of (n, c, h, w).

  • mask (torch.Tensor) – Tensor with shape of (n, c, h, w) or (n, 1, h, w). If mask is not given, the function will work as standard conv2d. Default: None.

  • return_mask (bool) – If True and mask is not None, the updated mask will be returned. Default: True.

Returns

Result after partial conv. If mask is given and return_mask is True, a 2-tuple of (result, updated mask) is returned.

Return type

torch.Tensor or tuple
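The re-weighting rule from the paper can be sketched for a single output location as follows. This is a plain-Python illustration of the idea, not the mmedit implementation; partial_conv_at and its arguments are hypothetical names:

```python
# Partial-convolution rule at one output window: valid pixels are
# re-weighted by (window size / valid-pixel count), and the mask is
# updated to 1 wherever the window saw any valid pixel.

def partial_conv_at(window, mask, weights, bias=0.0, eps=1e-8):
    masked = [x * m * w for x, m, w in zip(window, mask, weights)]
    mask_sum = sum(mask)
    if mask_sum > 0:
        scale = len(window) / (mask_sum + eps)
        out = sum(masked) * scale + bias
        new_mask = 1.0
    else:
        out, new_mask = 0.0, 0.0
    return out, new_mask

# A 3-pixel window where only the first two pixels are valid:
out, new_mask = partial_conv_at([1.0, 2.0, 9.0], [1, 1, 0], [1.0, 1.0, 1.0])
print(out, new_mask)  # the hole pixel (9.0) contributes nothing
```

The eps term is the same stabilizer documented in the constructor: it guards the division when almost all pixels in the window are holes.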

class mmedit.models.editors.PConvDecoder(num_layers=7, interpolation='nearest', conv_cfg=dict(type='PConv', multi_channel=True), norm_cfg=dict(type='BN'))

Bases: torch.nn.Module

Decoder with partial conv.

For details of this architecture, please see: Image Inpainting for Irregular Holes Using Partial Convolutions

Parameters
  • num_layers (int) – The number of convolutional layers. Default: 7.

  • interpolation (str) – The upsample mode. Default: ‘nearest’.

  • conv_cfg (dict) – Config for convolution module. Default: {‘type’: ‘PConv’, ‘multi_channel’: True}.

  • norm_cfg (dict) – Config for norm layer. Default: {‘type’: ‘BN’}.

forward(input_dict)

Forward Function.

Parameters

input_dict (dict | torch.Tensor) – Input dict with middle features or torch.Tensor.

Returns

Output tensor with shape of (n, c, h, w).

Return type

torch.Tensor

class mmedit.models.editors.PConvEncoder(in_channels=3, num_layers=7, conv_cfg=dict(type='PConv', multi_channel=True), norm_cfg=dict(type='BN', requires_grad=True), norm_eval=False)

Bases: torch.nn.Module

Encoder with partial conv.

For details of this architecture, please see: Image Inpainting for Irregular Holes Using Partial Convolutions

Parameters
  • in_channels (int) – The number of input channels. Default: 3.

  • num_layers (int) – The number of convolutional layers. Default: 7.

  • conv_cfg (dict) – Config for convolution module. Default: {‘type’: ‘PConv’, ‘multi_channel’: True}.

  • norm_cfg (dict) – Config for norm layer. Default: {‘type’: ‘BN’}.

  • norm_eval (bool) – Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effective on Batch Norm and its variants only. Default: False.

train(mode=True)

Set BatchNorm modules in the model to evaluation mode.

forward(x, mask)

Forward function for partial conv encoder.

Parameters
  • x (torch.Tensor) – Masked image with shape (n, c, h, w).

  • mask (torch.Tensor) – Mask tensor with shape (n, c, h, w).

Returns

A dict containing the result and mid-level features of this module: hidden_feats contains the middle feature maps and hidden_masks stores the updated masks.

Return type

dict

class mmedit.models.editors.PConvEncoderDecoder(encoder, decoder)

Bases: torch.nn.Module

Encoder-Decoder with partial conv module.

Parameters
  • encoder (dict) – Config of the encoder.

  • decoder (dict) – Config of the decoder.

forward(x, mask_in)

Forward Function.

Parameters
  • x (torch.Tensor) – Input tensor with shape of (n, c, h, w).

  • mask_in (torch.Tensor) – Input tensor with shape of (n, c, h, w).

Returns

Output tensor with shape of (n, c, h’, w’).

Return type

torch.Tensor

init_weights(pretrained=None, strict=True)

Init weights for models.

Parameters
  • pretrained (str, optional) – Path for pretrained weights. If given None, pretrained weights will not be loaded. Defaults to None.

  • strict (bool, optional) – Whether to strictly load the pretrained model. Defaults to True.
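A hypothetical config sketch for building a PConvEncoderDecoder from the encoder and decoder documented above. The keys mirror the documented defaults; check your mmedit version's registry for the exact type names:

```python
# Illustrative config dict for PConvEncoderDecoder; the nested encoder and
# decoder dicts reuse the documented defaults of PConvEncoder/PConvDecoder.
encdec_cfg = dict(
    type='PConvEncoderDecoder',
    encoder=dict(
        type='PConvEncoder',
        in_channels=3,
        num_layers=7,
        norm_cfg=dict(type='BN', requires_grad=True),
    ),
    decoder=dict(
        type='PConvDecoder',
        num_layers=7,
        interpolation='nearest',
    ),
)
print(encdec_cfg['encoder']['num_layers'], encdec_cfg['decoder']['interpolation'])
```

In mmedit-style configs, such a dict would typically be passed to the model builder rather than instantiated directly.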

class mmedit.models.editors.PConvInpaintor(data_preprocessor: Union[dict, mmengine.config.Config], encdec, disc=None, loss_gan=None, loss_gp=None, loss_disc_shift=None, loss_composed_percep=None, loss_out_percep=False, loss_l1_hole=None, loss_l1_valid=None, loss_tv=None, train_cfg=None, test_cfg=None, init_cfg: Optional[dict] = None)

Bases: mmedit.models.base_models.OneStageInpaintor

Inpaintor for Partial Convolution method.

This inpaintor is implemented according to the paper: Image inpainting for irregular holes using partial convolutions

forward_test(inputs, data_samples)

Forward function for testing.

Parameters
  • inputs (torch.Tensor) – Input tensor.

  • data_samples (List[dict]) – List of data sample dict.

Returns

Contains output results and eval metrics (if any).

Return type

dict

forward_tensor(inputs, data_samples)

Forward function in tensor mode.

Parameters
  • inputs (torch.Tensor) – Input tensor.

  • data_samples (dict) – Dict that contains the data sample.

Returns

Dict contains output results.

Return type

dict

train_step(data: List[dict], optim_wrapper)

Train step function.

In this function, the inpaintor will finish the train step following the pipeline:

  1. get fake res/image

  2. optimize discriminator (if have)

  3. optimize generator

If self.train_cfg.disc_step > 1, the train step will contain multiple iterations for optimizing the discriminator with different input data, and only one iteration for optimizing the generator after disc_step iterations for the discriminator.

Parameters
  • data (List[dict]) – Batch of data as input.

  • optim_wrapper (dict[torch.optim.Optimizer]) – Dict with optimizers for generator and discriminator (if have).

Returns

Dict with loss, information for logger, the number of samples and results for visualization.

Return type

dict
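The disc_step schedule described above can be sketched as follows. This is a schematic illustration, not mmedit code; schedule and its arguments are hypothetical names:

```python
# With disc_step > 1, each generator update is preceded by `disc_step`
# discriminator updates on successive batches.

def schedule(num_batches, disc_step):
    steps = []
    for it in range(num_batches):
        steps.append('disc')
        if (it + 1) % disc_step == 0:
            steps.append('gen')
    return steps

print(schedule(4, 2))  # ['disc', 'disc', 'gen', 'disc', 'disc', 'gen']
```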

class mmedit.models.editors.ProgressiveGrowingGAN(generator, discriminator, data_preprocessor, nkimgs_per_scale, noise_size=None, interp_real=None, transition_kimgs: int = 600, prev_stage: int = 0, ema_config: Optional[Dict] = None)

Bases: mmedit.models.base_models.BaseGAN

Progressive Growing Unconditional GAN.

In this GAN model, we implement the progressive growing training schedule, which is proposed in Progressive Growing of GANs for Improved Quality, Stability and Variation, ICLR 2018.

We highly recommend using GrowScaleImgDataset to save computational load in data pre-processing.

Notes for using PGGAN:

  1. In official implementation, Tero uses gradient penalty with norm_mode="HWC"

  2. We do not implement minibatch_repeats, which is used in the official TensorFlow implementation.

Notes for resuming progressive growing GANs: Users should specify the prev_stage in train_cfg. Otherwise, the model may reset the optimizer status, which brings inferior performance. For example, if your model is resumed from the 256 stage, you should set train_cfg=dict(prev_stage=256).

Parameters
  • generator (dict) – Config for generator.

  • discriminator (dict) – Config for discriminator.

forward(inputs: mmedit.utils.typing.ForwardInputs, data_samples: Optional[list] = None, mode: Optional[str] = None) mmedit.utils.typing.SampleList

Sample images from noises by using the generator.

Parameters
  • batch_inputs (ForwardInputs) – Dict containing the necessary information (e.g. noise, num_batches, mode) to generate image.

  • data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.

  • mode (Optional[str]) – mode is not used in ProgressiveGrowingGAN. Defaults to None.

Returns

A list of EditDataSample containing generated results.

Return type

SampleList

train_discriminator(inputs: torch.Tensor, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train discriminator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Not used in the discriminator’s training.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor, fake_data: torch.Tensor, real_data: torch.Tensor) Tuple[torch.Tensor, dict]

Get disc loss. PGGAN uses WGAN-GP loss and a discriminator shift loss to train the discriminator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

  • fake_data (Tensor) – Generated images, used to calculate gradient penalty.

  • real_data (Tensor) – Real images, used to calculate gradient penalty.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]
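The terms named above can be sketched arithmetically. This is an illustration of the objective, not mmedit code: the gradient-penalty value is shown as a precomputed number since computing it needs autograd, and the weights are illustrative, not the configured defaults:

```python
# WGAN critic terms plus a drift/shift penalty on the real logits.

def pggan_disc_loss(pred_fake, pred_real, grad_penalty,
                    gp_weight=10.0, shift_weight=0.001):
    mean = lambda xs: sum(xs) / len(xs)
    # Wasserstein term: push real logits up, fake logits down.
    loss_wgan = mean(pred_fake) - mean(pred_real)
    # Shift loss discourages the real logits from drifting away from zero.
    loss_shift = shift_weight * mean([p * p for p in pred_real])
    return loss_wgan + gp_weight * grad_penalty + loss_shift

print(pggan_disc_loss([0.2, -0.1], [1.0, 3.0], grad_penalty=0.05))
```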

train_generator(inputs: torch.Tensor, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train generator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Not used in the generator’s training.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

gen_loss(disc_pred_fake: torch.Tensor) Tuple[torch.Tensor, dict]

Generator loss for PGGAN. PGGAN uses WGAN’s loss to train the generator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]

train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict)

Train step function.

This function implements the standard training iteration for asynchronous adversarial training. Namely, in each iteration, we first update discriminator and then compute loss for generator with the newly updated discriminator.

As for distributed training, we use the reducer from ddp to synchronize the necessary params in current computational graph.

Parameters
  • data (dict) – Input data from dataloader.

  • optim_wrapper (OptimWrapperDict) – Dict containing the optim wrappers for the generator and discriminator.

Returns

Contains ‘log_vars’, ‘num_samples’, and ‘results’.

Return type

dict

class mmedit.models.editors.Pix2Pix(*args, **kwargs)

Bases: mmedit.models.base_models.BaseTranslationModel

Pix2Pix model for paired image-to-image translation.

Ref:

Image-to-Image Translation with Conditional Adversarial Networks

forward_test(img, target_domain, **kwargs)

Forward function for testing.

Parameters
  • img (tensor) – Input image tensor.

  • target_domain (str) – Target domain of output image.

  • kwargs (dict) – Other arguments.

Returns

Forward results.

Return type

dict

_get_disc_loss(outputs)

Get the loss of discriminator.

Parameters

outputs (dict) – A dict of output.

Returns

Loss and a dict of log of loss terms.

Return type

Tuple

_get_gen_loss(outputs)

Get the loss of generator.

Parameters

outputs (dict) – A dict of output.

Returns

Loss and a dict of log of loss terms.

Return type

Tuple
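The original Pix2Pix paper combines an adversarial term with an L1 reconstruction term weighted by lambda = 100. A toy sketch of that combination (names and numbers are illustrative, not mmedit internals):

```python
# Generator objective: adversarial loss plus weighted mean-absolute-error
# between the generated image and the paired target.

def pix2pix_gen_loss(gan_loss, fake, target, l1_weight=100.0):
    l1 = sum(abs(f - t) for f, t in zip(fake, target)) / len(fake)
    return gan_loss + l1_weight * l1

print(pix2pix_gen_loss(0.7, [0.5, 0.5], [0.5, 0.7]))
```

The large L1 weight keeps outputs close to the paired ground truth, while the adversarial term sharpens high-frequency detail.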

train_step(data, optim_wrapper=None)

Training step function.

Parameters
  • data (dict) – Dict of the input data batch.

  • optim_wrapper (dict[torch.optim.Optimizer], optional) – Dict of optimizers for the generator and discriminator. Defaults to None.

Returns

Dict of loss, information for logger, the number of samples and results for visualization.

Return type

dict

test_step(data: dict) mmedit.utils.typing.SampleList

Gets the generated image of given data. Same as val_step().

Parameters

data (dict) – Data sampled from metric specific sampler. More details in Metrics and Evaluator.

Returns

Generated image or image dict.

Return type

List[EditDataSample]

val_step(data: dict) mmedit.utils.typing.SampleList

Gets the generated image of given data. Same as test_step().

Parameters

data (dict) – Data sampled from metric specific sampler. More details in Metrics and Evaluator.

Returns

Generated image or image dict.

Return type

List[EditDataSample]

class mmedit.models.editors.PlainDecoder(in_channels, init_cfg: Optional[dict] = None)

Bases: mmengine.model.BaseModule

Simple decoder from Deep Image Matting.

Parameters
  • in_channels (int) – Channel num of input features.

  • init_cfg (dict, optional) – Initialization config dict. Defaults to None.

init_weights()

Init weights for the module.

forward(inputs)

Forward function of PlainDecoder.

Parameters

inputs (dict) –

Output dictionary of the VGG encoder containing:

  • out (Tensor): Output of the VGG encoder.

  • max_idx_1 (Tensor): Index of the first maxpooling layer in the VGG encoder.

  • max_idx_2 (Tensor): Index of the second maxpooling layer in the VGG encoder.

  • max_idx_3 (Tensor): Index of the third maxpooling layer in the VGG encoder.

  • max_idx_4 (Tensor): Index of the fourth maxpooling layer in the VGG encoder.

  • max_idx_5 (Tensor): Index of the fifth maxpooling layer in the VGG encoder.

Returns

Output tensor.

Return type

Tensor

class mmedit.models.editors.PlainRefiner(conv_channels=64, init_cfg=None)

Bases: mmengine.model.BaseModule

Simple refiner from Deep Image Matting.

Parameters
  • conv_channels (int) – Number of channels produced by the three main convolutional layers. Default: 64.

  • init_cfg (dict, optional) – Initialization config dict. Default: None.

init_weights()

Init weights for the module.

forward(x, raw_alpha)

Forward function.

Parameters
  • x (Tensor) – The input feature map of refiner.

  • raw_alpha (Tensor) – The raw predicted alpha matte.

Returns

The refined alpha matte.

Return type

Tensor

class mmedit.models.editors.RDNNet(in_channels, out_channels, mid_channels=64, num_blocks=16, upscale_factor=4, num_layers=8, channel_growth=64)

Bases: mmengine.model.BaseModule

RDN model for single image super-resolution.

Paper: Residual Dense Network for Image Super-Resolution

Adapted from ‘https://github.com/yjn870/RDN-pytorch.git’ ‘RDN-pytorch/blob/master/models.py’ Copyright (c) 2021, JaeYun Yeo, under MIT License.

Most of the implementation follows the implementation in: ‘https://github.com/sanghyun-son/EDSR-PyTorch.git’ ‘EDSR-PyTorch/blob/master/src/model/rdn.py’ Copyright (c) 2017, sanghyun-son, under MIT license.

Parameters
  • in_channels (int) – Channel number of inputs.

  • out_channels (int) – Channel number of outputs.

  • mid_channels (int) – Channel number of intermediate features. Default: 64.

  • num_blocks (int) – Block number in the trunk network. Default: 16.

  • upscale_factor (int) – Upsampling factor. Support 2^n and 3. Default: 4.

  • num_layers (int) – Number of layers in each Residual Dense Block. Default: 8.

  • channel_growth (int) – Channels growth in each layer of RDB. Default: 64.

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor
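The relationship between the constructor arguments and the output shape can be checked with simple arithmetic. A sketch, not a call into the network:

```python
# The output spatial size is the input size multiplied by `upscale_factor`,
# with `out_channels` channels; batch size is unchanged.

def rdn_output_shape(n, h, w, out_channels=3, upscale_factor=4):
    return (n, out_channels, h * upscale_factor, w * upscale_factor)

print(rdn_output_shape(1, 48, 48))  # (1, 3, 192, 192)
```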

class mmedit.models.editors.RealBasicVSR(generator, discriminator=None, gan_loss=None, pixel_loss=None, cleaning_loss=None, perceptual_loss=None, is_use_sharpened_gt_in_pixel=False, is_use_sharpened_gt_in_percep=False, is_use_sharpened_gt_in_gan=False, is_use_ema=False, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)

Bases: mmedit.models.editors.real_esrgan.RealESRGAN

RealBasicVSR model for real-world video super-resolution.

Ref: Investigating Tradeoffs in Real-World Video Super-Resolution, arXiv

Parameters
  • generator (dict) – Config for the generator.

  • discriminator (dict, optional) – Config for the discriminator. Default: None.

  • gan_loss (dict, optional) – Config for the gan loss. Note that the loss weight in gan loss is only for the generator.

  • pixel_loss (dict, optional) – Config for the pixel loss. Default: None.

  • cleaning_loss (dict, optional) – Config for the image cleaning loss. Default: None.

  • perceptual_loss (dict, optional) – Config for the perceptual loss. Default: None.

  • is_use_sharpened_gt_in_pixel (bool, optional) – Whether to use the image sharpened by unsharp masking as the GT for pixel loss. Default: False.

  • is_use_sharpened_gt_in_percep (bool, optional) – Whether to use the image sharpened by unsharp masking as the GT for perceptual loss. Default: False.

  • is_use_sharpened_gt_in_gan (bool, optional) – Whether to use the image sharpened by unsharp masking as the GT for adversarial loss. Default: False.

  • train_cfg (dict) – Config for training. Default: None. You may change the training of the GAN by setting disc_steps (how many discriminator updates after one generator update) and disc_init_steps (how many discriminator updates at the start of the training). These two keys are useful when training with WGAN.

  • test_cfg (dict) – Config for testing. Default: None.

  • init_cfg (dict, optional) – The weight initialized config for BaseModule. Default: None.

  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor. Default: None.

extract_gt_data(data_samples)

Extract GT data from data samples.

Parameters

data_samples (list) – List of EditDataSample.

Returns

Extracted GT data.

Return type

Tensor

g_step(batch_outputs, batch_gt_data)

G step of GAN: Calculate losses of generator.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tuple[Tensor]) – Batch GT data.

Returns

Dict of losses.

Return type

dict

d_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)

D step with optim of GAN: Calculate losses of discriminator and run optim.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

  • optim_wrapper (OptimWrapperDict) – Optim wrapper dict.

Returns

Dict of parsed losses.

Return type

dict

forward_train(batch_inputs, data_samples=None)

Forward Train.

Run forward of generator with return_lqs=True

Parameters
  • batch_inputs (Tensor) – Batch inputs.

  • data_samples (List[EditDataSample]) – Data samples of editing. Default: None.

Returns

Result of generator.

(outputs, lqs)

Return type

Tuple[Tensor]

class mmedit.models.editors.RealBasicVSRNet(mid_channels=64, num_propagation_blocks=20, num_cleaning_blocks=20, dynamic_refine_thres=255, spynet_pretrained=None, is_fix_cleaning=False, is_sequential_cleaning=False)

Bases: mmengine.model.BaseModule

RealBasicVSR network structure for real-world video super-resolution.

Support only x4 upsampling.

Paper:

Investigating Tradeoffs in Real-World Video Super-Resolution, arXiv

Parameters
  • mid_channels (int, optional) – Channel number of the intermediate features. Default: 64.

  • num_propagation_blocks (int, optional) – Number of residual blocks in each propagation branch. Default: 20.

  • num_cleaning_blocks (int, optional) – Number of residual blocks in the image cleaning module. Default: 20.

  • dynamic_refine_thres (int, optional) – Stop cleaning the images when the residue is smaller than this value. Default: 255.

  • spynet_pretrained (str, optional) – Pre-trained model path of SPyNet. Default: None.

  • is_fix_cleaning (bool, optional) – Whether to fix the weights of the image cleaning module during training. Default: False.

  • is_sequential_cleaning (bool, optional) – Whether to clean the images sequentially. This is used to save GPU memory, but the speed is slightly slower. Default: False.

forward(lqs, return_lqs=False)

Forward function for RealBasicVSRNet.

Parameters
  • lqs (tensor) – Input low quality (LQ) sequence with shape (n, t, c, h, w).

  • return_lqs (bool) – Whether to return LQ sequence. Default: False.

Returns

Output HR sequence.

Return type

Tensor
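The dynamic_refine_thres behaviour described above can be sketched schematically. This is an assumed reading of the documented rule, not mmedit internals; clean_until_converged and its arguments are hypothetical names:

```python
# The cleaning module is applied repeatedly until the residue falls below
# the threshold (normalized here to [0, 1] from the 0-255 range).

def clean_until_converged(residues, thres=255, max_iters=3):
    """`residues` simulates the residue magnitude after each pass."""
    applied = 0
    for r in residues[:max_iters]:
        applied += 1
        if r < thres / 255.0:
            break
    return applied

# With the default (loose) threshold the first pass already suffices:
print(clean_until_converged([0.9, 0.3, 0.05], thres=255))
# With a tight threshold, cleaning runs until the residue is small:
print(clean_until_converged([0.9, 0.3, 0.05], thres=20))
```

This explains why the default of 255 effectively disables the dynamic stopping: any residue passes the test after one cleaning pass.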

class mmedit.models.editors.RealESRGAN(generator, discriminator=None, gan_loss=None, pixel_loss=None, perceptual_loss=None, is_use_sharpened_gt_in_pixel=False, is_use_sharpened_gt_in_percep=False, is_use_sharpened_gt_in_gan=False, is_use_ema=True, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)

Bases: mmedit.models.editors.srgan.SRGAN

Real-ESRGAN model for single image super-resolution.

Ref: Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data, 2021.

Note: generator_ema is realized in EMA_HOOK

Parameters
  • generator (dict) – Config for the generator.

  • discriminator (dict, optional) – Config for the discriminator. Default: None.

  • gan_loss (dict, optional) – Config for the gan loss. Note that the loss weight in gan loss is only for the generator.

  • pixel_loss (dict, optional) – Config for the pixel loss. Default: None.

  • perceptual_loss (dict, optional) – Config for the perceptual loss. Default: None.

  • is_use_sharpened_gt_in_pixel (bool, optional) – Whether to use the image sharpened by unsharp masking as the GT for pixel loss. Default: False.

  • is_use_sharpened_gt_in_percep (bool, optional) – Whether to use the image sharpened by unsharp masking as the GT for perceptual loss. Default: False.

  • is_use_sharpened_gt_in_gan (bool, optional) – Whether to use the image sharpened by unsharp masking as the GT for adversarial loss. Default: False.

  • is_use_ema (bool, optional) – Whether to apply exponential moving average on the network weights. Default: True.

  • train_cfg (dict) – Config for training. Default: None. You may change the training of the GAN by setting disc_steps (how many discriminator updates after one generator update) and disc_init_steps (how many discriminator updates at the start of the training). These two keys are useful when training with WGAN.

  • test_cfg (dict) – Config for testing. Default: None.

  • init_cfg (dict, optional) – The weight initialized config for BaseModule. Default: None.

  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor. Default: None.
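The is_use_sharpened_gt_in_* options rely on unsharp masking: the GT is sharpened by adding back a weighted high-frequency residual. A 1-D sketch of the idea; the box blur and weight here are assumptions for illustration, not the exact mmedit implementation:

```python
# Unsharp masking: sharpened = signal + weight * (signal - blurred signal).

def box_blur(sig):
    """Simple 3-tap box blur with edge clamping."""
    out = []
    for i in range(len(sig)):
        lo, hi = max(0, i - 1), min(len(sig), i + 2)
        out.append(sum(sig[lo:hi]) / (hi - lo))
    return out

def unsharp_mask(sig, weight=0.5):
    blurred = box_blur(sig)
    return [s + weight * (s - b) for s, b in zip(sig, blurred)]

edge = [0.0, 0.0, 1.0, 1.0]
print(unsharp_mask(edge))  # the edge transition is exaggerated
```

Training against the sharpened GT encourages the generator to produce crisper edges than the original ground truth would.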

forward_tensor(inputs, data_samples=None, training=False)

Forward tensor. Returns result of simple forward.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

  • training (bool) – Whether is training. Default: False.

Returns

result of simple forward.

Return type

Tensor

g_step(batch_outputs, batch_gt_data)

G step of GAN: Calculate losses of generator.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tuple[Tensor]) – Batch GT data.

Returns

Dict of losses.

Return type

dict

d_step_real(batch_outputs, batch_gt_data: torch.Tensor)

Real part of D step.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tuple[Tensor]) – Batch GT data.

Returns

Real part of gan_loss for discriminator.

Return type

Tensor

d_step_fake(batch_outputs, batch_gt_data)

Fake part of D step.

Parameters
  • batch_outputs (Tensor) – Output of generator.

  • batch_gt_data (Tuple[Tensor]) – Batch GT data.

Returns

Fake part of gan_loss for discriminator.

Return type

Tensor

extract_gt_data(data_samples)

Extract GT data from data samples.

Parameters

data_samples (list) – List of EditDataSample.

Returns

Extracted GT data.

Return type

Tensor

class mmedit.models.editors.UNetDiscriminatorWithSpectralNorm(in_channels, mid_channels=64, skip_connection=True)

Bases: mmengine.model.BaseModule

A U-Net discriminator with spectral normalization.

Parameters
  • in_channels (int) – Channel number of the input.

  • mid_channels (int, optional) – Channel number of the intermediate features. Default: 64.

  • skip_connection (bool, optional) – Whether to use skip connection. Default: True.

forward(img)

Forward function.

Parameters

img (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor

class mmedit.models.editors.SAGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = 128, num_classes: Optional[int] = None, ema_config: Optional[Dict] = None)

Bases: mmedit.models.base_models.BaseConditionalGAN

Implementation of Self-Attention Generative Adversarial Networks [https://arxiv.org/abs/1805.08318] (SAGAN), Spectral Normalization for Generative Adversarial Networks (SNGAN), and cGANs with Projection Discriminator (Proj-GAN).

Detailed architectures can be found in mmedit.models.editors.sagan.sagan_generator.SNGANGenerator and mmedit.models.editors.sagan.sagan_discriminator.ProjDiscriminator.

Parameters
  • generator (ModelType) – The config or model of the generator.

  • discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.

  • data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or GenDataPreprocessor.

  • generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.

  • discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.

  • noise_size (Optional[int]) – Size of the input noise vector. Defaults to 128.

  • num_classes (Optional[int]) – The number of classes you would like to generate. Defaults to None.

  • ema_config (Optional[Dict]) – The config for generator’s exponential moving average setting. Defaults to None.

disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor) Tuple[torch.Tensor, dict]

Get disc loss. SAGAN, SNGAN and Proj-GAN use hinge loss to train the discriminator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]

gen_loss(disc_pred_fake: torch.Tensor) Tuple[torch.Tensor, dict]

Get generator loss. SAGAN, SNGAN and Proj-GAN use hinge loss to train the generator.

Parameters

disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]
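The hinge losses named in disc_loss and gen_loss above can be sketched arithmetically. A plain-Python illustration of the standard formulation, not mmedit code:

```python
# Discriminator hinge loss: penalize real logits below +1 and fake logits
# above -1. Generator hinge loss: maximize the fake logits.

def hinge_disc_loss(pred_fake, pred_real):
    mean = lambda xs: sum(xs) / len(xs)
    loss_real = mean([max(0.0, 1.0 - p) for p in pred_real])
    loss_fake = mean([max(0.0, 1.0 + p) for p in pred_fake])
    return loss_real + loss_fake

def hinge_gen_loss(pred_fake):
    return -sum(pred_fake) / len(pred_fake)

print(hinge_disc_loss([-2.0, 0.5], [2.0, 0.5]))  # 1.0
print(hinge_gen_loss([-2.0, 0.5]))               # 0.75
```

Note that confidently classified samples (real logits above +1, fake logits below -1) contribute zero discriminator loss, which is what distinguishes hinge loss from the plain Wasserstein objective.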

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train discriminator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train generator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Not used in the generator’s training.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

class mmedit.models.editors.SinGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, num_scales: Optional[int] = None, iters_per_scale: int = 2000, noise_weight_init: int = 0.1, lr_scheduler_args: Optional[dict] = None, test_pkl_data: Optional[str] = None, ema_confg: Optional[dict] = None)

Bases: mmedit.models.base_models.BaseGAN

SinGAN.

This model implements the single-image generative adversarial model proposed in: Singan: Learning a Generative Model from a Single Natural Image, ICCV’19.

Notes for training:

  • This model should be trained with our dataset SinGANDataset.

  • In training, the total_iters argument is related to the number of scales in the image pyramid and iters_per_scale in the train_cfg. You should set it carefully in the training config file.

Notes for model architectures:

  • The generator and discriminator need num_scales in initialization. However, this argument is generated by the create_real_pyramid function from singan_dataset.py. The last element in the returned list (stop_scale) is the value for num_scales. Note that this scale is counted from zero. Please see our tutorial for SinGAN to obtain more details, or our standard config for reference.
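A hedged sketch of how a multi-scale pyramid determines the number of scales: repeatedly shrink the image by a scale factor until a minimum size is reached, and the index of the coarsest scale is the stop scale. The factor and minimum size below are illustrative assumptions, not the exact values used by create_real_pyramid:

```python
# Build a shrinking size pyramid; stop_scale = len(sizes) - 1, counted
# from zero as in the SinGAN paper.

def count_scales(size, scale_factor=0.75, min_size=25):
    sizes = [size]
    while round(sizes[-1] * scale_factor) >= min_size:
        sizes.append(round(sizes[-1] * scale_factor))
    return sizes

sizes = count_scales(250)
print(sizes)
print(len(sizes) - 1)  # stop_scale, counted from zero
```

This is why total_iters must be set consistently: training runs iters_per_scale iterations at each of the stop_scale + 1 scales.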

Parameters
  • generator (ModelType) – The config or model of the generator.

  • discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.

  • data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or GANDataPreprocessor.

  • generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.

  • discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.

  • num_scales (int) – The number of scales/stages in generator/ discriminator. Note that this number is counted from zero, which is the same as the original paper. Defaults to None.

  • iters_per_scale (int) – The training iteration for each resolution scale. Defaults to 2000.

  • noise_weight_init (float) – The initial weight of the fixed noise. Defaults to 0.1.

  • lr_scheduler_args (Optional[dict]) – Arguments for learning schedulers. Note that in SinGAN, we use MultiStepLR, which is the same as the original paper. If not passed, no learning schedule will be used. Defaults to None.

  • test_pkl_data (Optional[str]) – The path of the pickle file which contains the fixed noise and noise weight. This is required for testing. Defaults to None.

  • ema_config (Optional[Dict]) – The config for generator’s exponential moving average setting. Defaults to None.

load_test_pkl()

Load pickle for test.

_from_numpy(data: Tuple[list, numpy.ndarray]) Tuple[torch.Tensor, List[torch.Tensor]]

Convert input numpy array or list of numpy array to Tensor or list of Tensor.

Parameters

data (Tuple[list, np.ndarray]) – Input data to convert.

Returns

Converted Tensor or list of tensor.

Return type

Tuple[Tensor, List[Tensor]]

get_module(model: torch.nn.Module, module_name: str) torch.nn.Module

Get an inner module from model.

Since some models may be wrapped with DDP, we have to check whether the module can be indexed directly.

Parameters
  • model (nn.Module) – This model may wrapped with DDP or not.

  • module_name (str) – The name of specific module.

Returns

Returned sub module.

Return type

nn.Module
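
The DDP-unwrapping logic described above can be sketched roughly like this (a simplified stand-in for the actual implementation):

```python
def get_module(model, module_name):
    """Return a named submodule whether or not ``model`` is DDP-wrapped.

    DistributedDataParallel stores the wrapped network in its ``module``
    attribute, so we unwrap first if that attribute is present.
    """
    if hasattr(model, 'module'):
        model = model.module
    return getattr(model, module_name)
```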

construct_fixed_noises()

Construct the fixed noises list used in SinGAN.

forward(inputs: mmedit.utils.ForwardInputs, data_samples: Optional[list] = None, mode=None) List[mmedit.structures.EditDataSample]

Forward function for SinGAN. For SinGAN, inputs should be a dict that contains ‘num_batches’, ‘mode’ and other input arguments for the generator.

Parameters
  • inputs (dict) – Dict containing the necessary information (e.g., noise, num_batches, mode) to generate image.

  • data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.

  • mode (Optional[str]) – mode is not used in BaseConditionalGAN. Defaults to None.

gen_loss(disc_pred_fake: torch.Tensor, recon_imgs: torch.Tensor) Tuple[torch.Tensor, dict]

Generator loss for SinGAN. SinGAN uses WGAN loss and MSE loss to train the generator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • recon_imgs (Tensor) – Reconstructed images.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]
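
A rough numpy sketch of the WGAN-plus-MSE objective described above. The function name and the recon_weight value are assumptions for illustration, not mmedit's actual API:

```python
import numpy as np

def singan_gen_loss(disc_pred_fake, recon_imgs, real_imgs, recon_weight=10.0):
    """Sketch of SinGAN's generator objective (recon_weight is assumed).

    WGAN generator loss maximizes the critic score on fakes; the MSE term
    ties the reconstruction at the current scale to the real image.
    """
    loss_g = -float(np.mean(disc_pred_fake))
    loss_recon = float(np.mean((recon_imgs - real_imgs) ** 2))
    total = loss_g + recon_weight * loss_recon
    return total, {'loss_g': loss_g, 'loss_recon': loss_recon}
```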

disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor, fake_data: torch.Tensor, real_data: torch.Tensor) Tuple[torch.Tensor, dict]

Get disc loss. SinGAN uses WGAN loss with gradient penalty to train the discriminator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

  • fake_data (Tensor) – Generated images, used to calculate gradient penalty.

  • real_data (Tensor) – Real images, used to calculate gradient penalty.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train generator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Not used in the generator’s training.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train discriminator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

train_gan(inputs_dict: dict, data_sample: List[mmedit.structures.EditDataSample], optim_wrapper: mmengine.optim.OptimWrapperDict) Dict[str, torch.Tensor]

Train GAN model. In the training of GAN models, the generator and discriminator are updated alternately. In MMGeneration’s design, self.train_step is called with the data input. Therefore we always update the discriminator, whose update relies on real data, and then determine whether the generator needs to be updated based on the current number of iterations. More details about whether to update the generator can be found in should_gen_update().

Parameters
  • inputs_dict (dict) – Data sampled from dataloader.

  • data_sample (List[EditDataSample]) – List of data sample contains GT and meta information.

  • optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance contains OptimWrapper of generator and discriminator.

Returns

A dict of tensor for logging.

Return type

Dict[str, torch.Tensor]

train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict) Dict[str, torch.Tensor]

Train step for SinGAN model. SinGAN is trained with multi-resolution images, and each resolution is trained for :attr:self.iters_per_scale times.

We initialize the weight and learning rate scheduler of the corresponding module at the start of each resolution’s training. At the end of each resolution’s training, we update the weight of the noise of the current resolution by the MSE loss between the reconstructed image and the real image.

Parameters
  • data (dict) – Data sampled from dataloader.

  • optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance contains OptimWrapper of generator and discriminator.

Returns

A dict of tensor for logging.

Return type

Dict[str, torch.Tensor]
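
The per-scale schedule described above can be illustrated with a tiny helper (hypothetical; the real train_step also reinitializes weights and schedulers at each scale boundary):

```python
def scale_of_iter(curr_iter, iters_per_scale=2000):
    """Which resolution scale is being trained at iteration ``curr_iter``.

    Sketch only: each scale is trained for exactly ``iters_per_scale``
    iterations, lowest resolution first.
    """
    return curr_iter // iters_per_scale
```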

test_step(data: dict) mmedit.utils.SampleList

Gets the generated images of the given data during testing. Before generating images, we call :meth:self.load_test_pkl to load the fixed noise and the current stage of the model from the pickle file.

Parameters

data (dict) – Data sampled from a metric-specific sampler. More details in Metrics and Evaluator.

Returns

A list of EditDataSample containing generated results.

Return type

SampleList

class mmedit.models.editors.SRCNNNet(channels=(3, 64, 32, 3), kernel_sizes=(9, 1, 5), upscale_factor=4)

Bases: mmengine.model.BaseModule

SRCNN network structure for image super resolution.

SRCNN has three conv layers. For each layer, we can define the in_channels, out_channels and kernel_size. The input image will first be upsampled with a bicubic upsampler, and then super-resolved in the HR spatial size.

Paper: Learning a Deep Convolutional Network for Image Super-Resolution.

Parameters
  • channels (tuple[int]) – A tuple of channel numbers for each layer, including the channels of input and output. Default: (3, 64, 32, 3).

  • kernel_sizes (tuple[int]) – A tuple of kernel sizes for each conv layer. Default: (9, 1, 5).

  • upscale_factor (int) – Upsampling factor. Default: 4.
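
How channels and kernel_sizes pair up into the three conv layers can be seen by zipping them (a sketch of the configuration logic only, not the network itself):

```python
# SRCNN's defaults: three convs, each taking (in_channels, out_channels,
# kernel_size) from consecutive entries of the two tuples.
channels = (3, 64, 32, 3)
kernel_sizes = (9, 1, 5)
layer_specs = list(zip(channels[:-1], channels[1:], kernel_sizes))
# layer_specs == [(3, 64, 9), (64, 32, 1), (32, 3, 5)]
```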

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor

class mmedit.models.editors.SRGAN(generator, discriminator=None, gan_loss=None, pixel_loss=None, perceptual_loss=None, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)

Bases: mmedit.models.base_models.BaseEditModel

SRGAN model for single image super-resolution.

Ref: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network.

Parameters
  • generator (dict) – Config for the generator.

  • discriminator (dict) – Config for the discriminator. Default: None.

  • gan_loss (dict) – Config for the gan loss. Note that the loss weight in gan loss is only for the generator.

  • pixel_loss (dict) – Config for the pixel loss. Default: None.

  • perceptual_loss (dict) – Config for the perceptual loss. Default: None.

  • train_cfg (dict) – Config for training. Default: None.

  • test_cfg (dict) – Config for testing. Default: None.

  • init_cfg (dict, optional) – The weight initialized config for BaseModule. Default: None.

  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor. Default: None.

forward_train(inputs, data_samples=None, **kwargs)

Forward training. Losses of training are calculated in train_step.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

Result of forward_tensor with training=True.

Return type

Tensor

forward_tensor(inputs, data_samples=None, training=False)

Forward tensor. Returns result of simple forward.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

  • training (bool) – Whether is training. Default: False.

Returns

result of simple forward.

Return type

Tensor

if_run_g()

Calculates whether the generator step needs to run.

if_run_d()

Calculates whether the discriminator step needs to run.

g_step(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor)

G step of GAN: Calculate losses of generator.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

Returns

Dict of losses.

Return type

dict

d_step_real(batch_outputs, batch_gt_data: torch.Tensor)

Real part of D step.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

Returns

Real part of gan_loss for discriminator.

Return type

Tensor

d_step_fake(batch_outputs: torch.Tensor, batch_gt_data)

Fake part of D step.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

Returns

Fake part of gan_loss for discriminator.

Return type

Tensor

g_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)

G step with optim of GAN: Calculate losses of generator and run optim.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

  • optim_wrapper (OptimWrapperDict) – Optim wrapper dict.

Returns

Dict of parsed losses.

Return type

dict

d_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)

D step with optim of GAN: Calculate losses of discriminator and run optim.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

  • optim_wrapper (OptimWrapperDict) – Optim wrapper dict.

Returns

Dict of parsed losses.

Return type

dict

extract_gt_data(data_samples)

Extract GT data from data samples.

Parameters

data_samples (list) – List of EditDataSample.

Returns

Extracted GT data.

Return type

Tensor

train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) Dict[str, torch.Tensor]

Train step of GAN-based method.

Parameters
  • data (List[dict]) – Data sampled from dataloader.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, torch.Tensor]
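
The if_run_g/if_run_d gating used by this train step can be sketched as below. The parameter names disc_steps and disc_init_steps are assumptions loosely modeled on common train_cfg keys, not mmedit's exact fields:

```python
def if_run_g(curr_iter, disc_steps=1, disc_init_steps=0):
    """Hypothetical generator-step gate.

    The generator runs only after a discriminator warm-up period
    (``disc_init_steps``) and then once every ``disc_steps`` iterations.
    """
    return curr_iter >= disc_init_steps and curr_iter % disc_steps == 0
```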

class mmedit.models.editors.ModifiedVGG(in_channels, mid_channels)

Bases: mmengine.model.BaseModule

A modified VGG discriminator with input size 128 x 128.

It is used to train SRGAN and ESRGAN.

Parameters
  • in_channels (int) – Channel number of inputs. Default: 3.

  • mid_channels (int) – Channel number of base intermediate features. Default: 64.

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor

class mmedit.models.editors.MSRResNet(in_channels, out_channels, mid_channels=64, num_blocks=16, upscale_factor=4)

Bases: mmengine.model.BaseModule

Modified SRResNet.

A compact version modified from SRResNet in “Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network”.

It uses residual blocks without BN, similar to EDSR. Currently, it supports x2, x3 and x4 upsampling scale factors.

Parameters
  • in_channels (int) – Channel number of inputs.

  • out_channels (int) – Channel number of outputs.

  • mid_channels (int) – Channel number of intermediate features. Default: 64.

  • num_blocks (int) – Block number in the trunk network. Default: 16.

  • upscale_factor (int) – Upsampling factor. Support x2, x3 and x4. Default: 4.

_supported_upscale_factors = [2, 3, 4]
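
The x2/x3/x4 restriction comes from how pixel-shuffle upsampling stages are stacked; a sketch of the usual plan (hypothetical helper, mirroring the common MSRResNet layout, not the actual module):

```python
def upsample_plan(upscale_factor, supported=(2, 3, 4)):
    """Return the per-stage pixel-shuffle factors for a given scale.

    x2 and x3 use a single upsampling stage; x4 stacks two x2 stages.
    Other factors are rejected, matching ``_supported_upscale_factors``.
    """
    if upscale_factor not in supported:
        raise ValueError(
            f'Unsupported scale {upscale_factor}; supported: {supported}')
    return [2, 2] if upscale_factor == 4 else [upscale_factor]
```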
forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor

init_weights()

Init weights for models.

class mmedit.models.editors.StyleGAN1(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, style_channels: int = 512, nkimgs_per_scale: dict = {}, interp_real: Optional[dict] = None, transition_kimgs: int = 600, prev_stage: int = 0, ema_config: Optional[Dict] = None)

Bases: mmedit.models.editors.pggan.ProgressiveGrowingGAN

Implementation of A Style-Based Generator Architecture for Generative Adversarial Networks.

<https://openaccess.thecvf.com/content_CVPR_2019/html/Karras_A_Style-Based_Generator_Architecture_for_Generative_Adversarial_Networks_CVPR_2019_paper.html>`_ # noqa (StyleGANv1). This class inherits from :class:~`ProgressiveGrowingGAN` to support progressive training.

Detailed architecture can be found in :class:~`mmgen.models.architectures.stylegan.generator_discriminator_v1.StyleGANv1Generator` # noqa and :class:~`mmgen.models.architectures.stylegan.generator_discriminator_v1.StyleGAN1Discriminator` # noqa

Parameters
  • generator (ModelType) – The config or model of the generator.

  • discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.

  • data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or GenDataPreprocessor.

  • style_channels (int) – The number of channels for style code. Defaults to 512.

  • nkimgs_per_scale (dict) – The number of images need for each resolution’s training. Defaults to {}.

  • interp_real (dict, optional) – The config of the interpolation method for real images. If not passed, bilinear interpolation with align_corners will be used. Defaults to None.

  • transition_kimgs (int, optional) – The number of images used to transition from the previous torgb layer to the new torgb layer. Defaults to 600.

  • prev_stage (int, optional) – The resolution of previous stage. Used for resume training. Defaults to 0.

  • ema_config (Optional[Dict]) – The config for generator’s exponential moving average setting. Defaults to None.

disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor, fake_data: torch.Tensor, real_data: torch.Tensor) Tuple[torch.Tensor, dict]

Get disc loss. StyleGANv1 uses non-saturating GAN loss and R1 gradient penalty to train the discriminator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

  • fake_data (Tensor) – Generated images, used to calculate gradient penalty.

  • real_data (Tensor) – Real images, used to calculate gradient penalty.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]

gen_loss(disc_pred_fake: torch.Tensor) Tuple[torch.Tensor, dict]

Generator loss for PGGAN. PGGAN uses WGAN loss to train the generator.

Parameters

disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]

class mmedit.models.editors.StyleGAN2(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, ema_config: Optional[Dict] = None, loss_config=dict())

Bases: mmedit.models.base_models.BaseGAN

Implementation of Analyzing and Improving the Image Quality of StyleGAN. # noqa.

Paper link: https://openaccess.thecvf.com/content_CVPR_2020/html/Karras_Analyzing_and_Improving_the_Image_Quality_of_StyleGAN_CVPR_2020_paper.html. # noqa

Detailed architecture can be found in :class:~`mmgen.models.architectures.stylegan.generator_discriminator_v2.StyleGANv2Generator` # noqa and :class:~`mmgen.models.architectures.stylegan.generator_discriminator_v2.StyleGAN2Discriminator` # noqa

Parameters
  • generator (ModelType) – The config or model of the generator.

  • discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.

  • data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or GenDataPreprocessor.

  • generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.

  • discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.

  • ema_config (Optional[Dict]) – The config for generator’s exponential moving average setting. Defaults to None.

disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor, real_imgs: torch.Tensor) Tuple

Get disc loss. StyleGANv2 uses the non-saturating loss and R1 gradient penalty to train the discriminator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

  • real_imgs (Tensor) – Input real images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]
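
The non-saturating part of this loss can be sketched in numpy as follows. The R1 gradient penalty requires autograd and is omitted here; function names are illustrative only:

```python
import numpy as np

def softplus(x):
    # numerically stable log(1 + exp(x))
    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0)

def d_loss_nonsaturating(disc_pred_fake, disc_pred_real):
    """Non-saturating discriminator loss: push real logits positive and
    fake logits negative. (R1 penalty on real images is omitted.)"""
    return float(np.mean(softplus(disc_pred_fake)) +
                 np.mean(softplus(-disc_pred_real)))
```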

gen_loss(disc_pred_fake: torch.Tensor, batch_size: int) Tuple

Get gen loss. StyleGANv2 uses the non-saturating loss and generator path regularization to train the generator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • batch_size (int) – Batch size for generating fake images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train discriminator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train generator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Not used in the generator’s training.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict) Dict[str, torch.Tensor]

Train GAN model. In the training of GAN models, the generator and discriminator are updated alternately. In MMEditing’s design, self.train_step is called with the data input. Therefore we always update the discriminator, whose update relies on real data, and then determine whether the generator needs to be updated based on the current number of iterations. More details about whether to update the generator can be found in should_gen_update().

Parameters
  • data (dict) – Data sampled from dataloader.

  • optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance contains OptimWrapper of generator and discriminator.

Returns

A dict of tensor for logging.

Return type

Dict[str, torch.Tensor]

class mmedit.models.editors.StyleGAN3(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, forward_kwargs: Optional[Dict] = None, ema_config: Optional[Dict] = None, loss_config=dict())

Bases: mmedit.models.editors.stylegan2.StyleGAN2

Implementation of Alias-Free Generative Adversarial Networks. # noqa.

Paper link: https://nvlabs-fi-cdn.nvidia.com/stylegan3/stylegan3-paper.pdf # noqa

Detailed architecture can be found in

:class:~`mmgen.models.architectures.stylegan.generator_discriminator_v3.StyleGANv3Generator` # noqa and :class:~`mmgen.models.architectures.stylegan.generator_discriminator_v2.StyleGAN2Discriminator` # noqa

test_step(data: dict) mmedit.utils.typing.SampleList

Gets the generated image of given data. Same as val_step().

Parameters

data (dict) – Data sampled from a metric-specific sampler. More details in Metrics and Evaluator.

Returns

A list of EditDataSample containing generated results.

Return type

SampleList

val_step(data: dict) mmedit.utils.typing.SampleList

Gets the generated image of given data. Same as test_step().

Parameters

data (dict) – Data sampled from a metric-specific sampler. More details in Metrics and Evaluator.

Returns

A list of EditDataSample containing generated results.

Return type

SampleList

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train discriminator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train generator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Do not used in generator’s training.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

sample_equivarience_pairs(batch_size, sample_mode='ema', eq_cfg=dict(compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False, translate_max=0.125, rotate_max=1), sample_kwargs=dict())
class mmedit.models.editors.StyleGAN3Generator(out_size, style_channels, img_channels, noise_size=512, rgb2bgr=False, pretrained=None, synthesis_cfg=dict(type='SynthesisNetwork'), mapping_cfg=dict(type='MappingNetwork'))

Bases: torch.nn.Module

StyleGAN3 Generator.

In StyleGAN3, we make several changes to StyleGANv2’s generator which include transformed fourier features, filtered nonlinearities and non-critical sampling, etc. More details can be found in: Alias-Free Generative Adversarial Networks NeurIPS’2021.

Ref: https://github.com/NVlabs/stylegan3

Parameters
  • out_size (int) – The output size of the StyleGAN3 generator.

  • style_channels (int) – The number of channels for style code.

  • img_channels (int) – The number of output’s channels.

  • noise_size (int, optional) – Size of the input noise vector. Defaults to 512.

  • rgb2bgr (bool, optional) – Whether to reformat the output channels with order bgr. We provide several pre-trained StyleGAN3 weights whose output channels order is rgb. You can set this argument to True to use the weights.

  • pretrained (str | dict, optional) – Path for the pretrained model, or a dict containing information for pretrained models whose necessary key is ‘ckpt_path’. Besides, you can also provide ‘prefix’ to load the generator part from the whole state dict. Defaults to None.

  • synthesis_cfg (dict, optional) – Config for synthesis network. Defaults to dict(type=’SynthesisNetwork’).

  • mapping_cfg (dict, optional) – Config for mapping network. Defaults to dict(type=’MappingNetwork’).

_load_pretrained_model(ckpt_path, prefix='', map_location='cpu', strict=True)
forward(noise, num_batches=0, input_is_latent=False, truncation=1, num_truncation_layer=None, update_emas=False, force_fp32=True, return_noise=False, return_latents=False)

Forward Function for stylegan3.

Parameters
  • noise (torch.Tensor | callable | None) – You can directly give a batch of noise through a torch.Tensor or offer a callable function to sample a batch of noise data. Otherwise, the None indicates to use the default noise sampler.

  • num_batches (int, optional) – The number of batch size. Defaults to 0.

  • input_is_latent (bool, optional) – If True, the input tensor is the latent tensor. Defaults to False.

  • truncation (float, optional) – Truncation factor. If a value less than 1 is given, the truncation trick will be adopted. Defaults to 1.

  • num_truncation_layer (int, optional) – Number of layers use truncated latent. Defaults to None.

  • update_emas (bool, optional) – Whether to update the moving average of the mean latent. Defaults to False.

  • force_fp32 (bool, optional) – Force fp32 computation regardless of the weights’ dtype. Defaults to True.

  • return_noise (bool, optional) – If True, noise_batch will be returned in a dict with fake_img. Defaults to False.

  • return_latents (bool, optional) – If True, latent will be returned in a dict with fake_img. Defaults to False.

Returns

Generated image tensor or dictionary containing more data.

Return type

torch.Tensor | dict
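
The truncation trick mentioned above (truncation < 1) pulls sampled latents toward the mean latent obtained from get_mean_latent(); a minimal numpy sketch (names are illustrative):

```python
import numpy as np

def truncate_latent(w, w_mean, truncation=0.7):
    """Truncation trick sketch: interpolate latents toward the mean latent.

    ``truncation == 1`` leaves ``w`` unchanged; ``truncation == 0``
    collapses every latent onto ``w_mean``.
    """
    return w_mean + truncation * (w - w_mean)
```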

get_mean_latent(num_samples=4096, **kwargs)

Get mean latent of W space in this generator.

Parameters

num_samples (int, optional) – Number of sample times. Defaults to 4096.

Returns

Mean latent of this generator.

Return type

Tensor

get_training_kwargs(phase)

Get training kwargs. In StyleGANv3, we enable fp16 and update the magnitude ema during training of the discriminator. This function is used to pass related arguments.

Parameters

phase (str) – Current training phase.

Returns

Training kwargs.

Return type

dict

class mmedit.models.editors.TDAN(generator, pixel_loss, lq_pixel_loss, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)

Bases: mmedit.models.BaseEditModel

TDAN model for video super-resolution.

Paper:

TDAN: Temporally-Deformable Alignment Network for Video Super- Resolution, CVPR, 2020

Parameters
  • generator (dict) – Config for the generator structure.

  • pixel_loss (dict) – Config for pixel-wise loss.

  • lq_pixel_loss (dict) – Config for pixel-wise loss for the LQ images.

  • train_cfg (dict) – Config for training. Default: None.

  • test_cfg (dict) – Config for testing. Default: None.

  • init_cfg (dict, optional) – The weight initialized config for BaseModule.

  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.

forward_train(inputs, data_samples=None, **kwargs)

Forward training. Returns dict of losses of training.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

Dict of losses.

Return type

dict

forward_tensor(inputs, data_samples=None, training=False, **kwargs)

Forward tensor. Returns result of simple forward.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

  • training (bool) – Whether is training. Default: False.

Returns

Results of forward inference and forward train.

Return type

(Tensor | List[Tensor])

class mmedit.models.editors.TDANNet(in_channels=3, mid_channels=64, out_channels=3, num_blocks_before_align=5, num_blocks_after_align=10)

Bases: mmengine.model.BaseModule

TDAN network structure for video super-resolution.

Support only x4 upsampling.

Paper:

TDAN: Temporally-Deformable Alignment Network for Video Super- Resolution, CVPR, 2020

Parameters
  • in_channels (int) – Number of channels of the input image. Default: 3.

  • mid_channels (int) – Number of channels of the intermediate features. Default: 64.

  • out_channels (int) – Number of channels of the output image. Default: 3.

  • num_blocks_before_align (int) – Number of residual blocks before temporal alignment. Default: 5.

  • num_blocks_after_align (int) – Number of residual blocks after temporal alignment. Default: 10.

forward(lrs)

Forward function for TDANNet.

Parameters

lrs (Tensor) – Input LR sequence with shape (n, t, c, h, w).

Returns

Output HR image with shape (n, c, 4h, 4w) and aligned LR images with shape (n, t, c, h, w).

Return type

tuple[Tensor]

class mmedit.models.editors.TOFlowVFINet(rgb_mean=[0.485, 0.456, 0.406], rgb_std=[0.229, 0.224, 0.225], flow_cfg=dict(norm_cfg=None, pretrained=None), init_cfg=None)

Bases: mmengine.model.BaseModule

PyTorch implementation of TOFlow for video frame interpolation.

Paper: Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018 Code reference:

  1. https://github.com/anchen1011/toflow

  2. https://github.com/Coldog2333/pytoflow

Parameters
  • rgb_mean (list[float]) – Image mean in RGB orders. Default: [0.485, 0.456, 0.406]

  • rgb_std (list[float]) – Image std in RGB orders. Default: [0.229, 0.224, 0.225]

  • flow_cfg (dict) – Config of SPyNet. Default: dict(norm_cfg=None, pretrained=None)

  • init_cfg (dict, optional) – Initialization config dict. Default: None.
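
The rgb_mean/rgb_std arguments imply the usual per-channel normalization of the input frames; a numpy sketch of the round trip (helper names are illustrative):

```python
import numpy as np

# Per-channel statistics, broadcastable over (b, 3, h, w) frames.
rgb_mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
rgb_std = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)

def normalize(imgs):
    """imgs: (b, 3, h, w) in [0, 1]; returns mean/std-normalized frames."""
    return (imgs - rgb_mean) / rgb_std

def denormalize(imgs):
    """Inverse of ``normalize``; maps back to the original [0, 1] range."""
    return imgs * rgb_std + rgb_mean
```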

forward(imgs)
Parameters

imgs – Input frames with shape of (b, 2, 3, h, w).

Returns

Interpolated frame with shape of (b, 3, h, w).

Return type

Tensor

class mmedit.models.editors.TOFlowVSRNet(adapt_official_weights=False)

Bases: mmengine.model.BaseModule

PyTorch implementation of TOFlow.

In TOFlow, the LR frames are pre-upsampled and have the same size with the GT frames.

Paper: Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018 Code reference:

  1. https://github.com/anchen1011/toflow

  2. https://github.com/Coldog2333/pytoflow

Parameters

adapt_official_weights (bool) – Whether to adapt the weights translated from the official implementation. Set to false if you want to train from scratch. Default: False

forward(lrs)
Parameters

lrs – Input lr frames: (b, 7, 3, h, w).

Returns

SR frame: (b, 3, h, w).

Return type

Tensor

class mmedit.models.editors.ToFResBlock

Bases: torch.nn.Module

ResNet architecture.

Three-layers ResNet/ResBlock

forward(frames)
Parameters

frames (Tensor) – Tensor with shape of (b, 2, 3, h, w).

Returns

Interpolated frame with shape of (b, 3, h, w).

Return type

Tensor

class mmedit.models.editors.LTE(requires_grad=True, pixel_range=1.0, load_pretrained_vgg=True)

Bases: mmengine.model.BaseModule

Learnable Texture Extractor.

Based on pretrained VGG19. Generate features in 3 levels.

Parameters
  • requires_grad (bool) – Require grad or not. Default: True.

  • pixel_range (float) – Pixel range of feature. Default: 1.

  • load_pretrained_vgg (bool) – Whether to load pretrained VGG from torchvision. Default: True. For training, the pretrained VGG must be loaded; for evaluation, it need not be, because the pretrained LTE will be loaded.

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, 3, h, w).

Returns

Forward results in 3 levels.

x_level3: Forward results in level 3 (n, 256, h/4, w/4).

x_level2: Forward results in level 2 (n, 128, h/2, w/2).

x_level1: Forward results in level 1 (n, 64, h, w).

Return type

Tuple[Tensor]

class mmedit.models.editors.TTSR(generator, extractor, transformer, pixel_loss, discriminator=None, perceptual_loss=None, transferal_perceptual_loss=None, gan_loss=None, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)

Bases: mmedit.models.editors.srgan.SRGAN

TTSR model for Reference-based Image Super-Resolution.

Paper: Learning Texture Transformer Network for Image Super-Resolution.

Parameters
  • generator (dict) – Config for the generator.

  • extractor (dict) – Config for the extractor.

  • transformer (dict) – Config for the transformer.

  • pixel_loss (dict) – Config for the pixel loss.

  • discriminator (dict) – Config for the discriminator. Default: None.

  • perceptual_loss (dict) – Config for the perceptual loss. Default: None.

  • transferal_perceptual_loss (dict) – Config for the transferal perceptual loss. Default: None.

  • gan_loss (dict) – Config for the GAN loss. Default: None

  • train_cfg (dict) – Config for train. Default: None.

  • test_cfg (dict) – Config for testing. Default: None.

  • init_cfg (dict, optional) – The weight initialized config for BaseModule. Default: None.

  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor. Default: None.

forward_tensor(inputs, data_samples=None, training=False)

Forward tensor. Returns result of simple forward.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

  • training (bool) – Whether is training. Default: False.

Returns

Results of forward inference or forward training.

Return type

(Tensor | Tuple[List[Tensor]])

if_run_g()

Calculate whether the generator step needs to run.

if_run_d()

Calculate whether the discriminator step needs to run.

g_step(batch_outputs, batch_gt_data: mmedit.structures.EditDataSample)

G step of GAN: Calculate losses of generator.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

Returns

Dict of losses.

Return type

dict

g_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)

G step with optim of GAN: Calculate losses of generator and run optim.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

  • optim_wrapper (OptimWrapperDict) – Optim wrapper dict.

Returns

Dict of parsed losses.

Return type

dict

d_step_with_optim(batch_outputs, batch_gt_data, optim_wrapper)

D step with optim of GAN: Calculate losses of discriminator and run optim.

Parameters
  • batch_outputs (Tuple[Tensor]) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

  • optim_wrapper (OptimWrapper) – Optim wrapper of discriminator.

Returns

Dict of parsed losses.

Return type

dict

class mmedit.models.editors.SearchTransformer

Bases: torch.nn.Module

Search texture reference by transformer.

Include relevance embedding, hard-attention and soft-attention.

gather(inputs, dim, index)

Hard Attention. Gathers values along an axis specified by dim.

Parameters
  • inputs (Tensor) – The source tensor. (N, C*k*k, H*W)

  • dim (int) – The axis along which to index.

  • index (Tensor) – The indices of elements to gather. (N, H*W)

Returns

outputs (Tensor): The result tensor. (N, C*k*k, H*W)
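The gather step above can be sketched in NumPy (a stand-in for `torch.gather`, with shapes following the docstring): the (N, H*W) index is broadcast across the channel axis before gathering along the spatial axis.

```python
import numpy as np

def hard_attention_gather(inputs, index):
    """Gather unfolded reference features at the most relevant positions.

    inputs: (N, C*k*k, H*W) unfolded reference features.
    index:  (N, H*W) argmax indices from the relevance embedding.
    Returns an array of shape (N, C*k*k, H*W).
    """
    n, ckk, hw = inputs.shape
    # Expand the per-position index across the channel axis, then gather
    # along the last (spatial) axis -- the NumPy analogue of torch.gather.
    expanded = np.broadcast_to(index[:, None, :], (n, ckk, hw))
    return np.take_along_axis(inputs, expanded, axis=2)

# Toy example: 1 sample, 2 "channels", 3 spatial positions.
feats = np.array([[[10, 20, 30], [40, 50, 60]]])
idx = np.array([[2, 0, 1]])
out = hard_attention_gather(feats, idx)
# out[0, 0] -> [30, 10, 20]; out[0, 1] -> [60, 40, 50]
```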

forward(img_lq, ref_lq, refs)

Texture transformer.

Q = LTE(img_lq), K = LTE(ref_lq), V = LTE(ref), from V_level_n to V_level_1.

Relevance embedding aims to embed the relevance between the LQ and Ref images by estimating the similarity between Q and K.

Hard-Attention: only transfer features from the most relevant position in V for each query.

Soft-Attention: synthesize features from the transferred GT texture features T and the LQ features F from the backbone.

Note: all arguments are features produced by the extractor. Each contains 3 levels; when upscale_factor=4, the size ratio of these features is level3:level2:level1 = 1:2:4.

Parameters

  • img_lq (Tensor) – Tensor of 4x bicubic-upsampled lq image. (N, C, H, W)

  • ref_lq (Tensor) – Tensor of ref_lq. ref_lq is obtained by applying bicubic down-sampling and up-sampling with factor 4x on ref. (N, C, H, W)

  • refs (Tuple[Tensor]) – Tuple of ref tensors. [(N, C, H, W), (N, C/2, 2H, 2W), …]

Returns

tuple contains:

soft_attention (Tensor): Soft-Attention tensor. (N, 1, H, W)

textures (Tuple[Tensor]): Transferred GT textures. [(N, C, H, W), (N, C/2, 2H, 2W), …]

Return type

tuple
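The relevance-embedding step described above can be sketched in NumPy (assumed mechanics following the paper's description, not mmedit's exact code): each unfolded query patch from Q is compared against all key patches from K by normalized inner product; the argmax gives the hard-attention index and the max similarity becomes the soft-attention weight.

```python
import numpy as np

def relevance_embedding(q, k):
    """q: (HW_q, D) unfolded LQ patches; k: (HW_k, D) unfolded ref patches.

    Returns (hard_idx, soft_att): the most relevant ref position for each
    query, and its cosine-similarity score.
    """
    q = q / np.linalg.norm(q, axis=1, keepdims=True)
    k = k / np.linalg.norm(k, axis=1, keepdims=True)
    rel = q @ k.T                   # (HW_q, HW_k) relevance map
    hard_idx = rel.argmax(axis=1)   # index used by the hard-attention gather
    soft_att = rel.max(axis=1)      # confidence used by soft-attention
    return hard_idx, soft_att

q = np.array([[1.0, 0.0], [0.0, 1.0]])
k = np.array([[0.0, 2.0], [3.0, 0.0]])
hard_idx, soft_att = relevance_embedding(q, k)
# q[0] matches k[1] exactly; q[1] matches k[0] exactly.
```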

class mmedit.models.editors.TTSRDiscriminator(in_channels=3, in_size=160)

Bases: mmengine.model.BaseModule

A discriminator for TTSR.

Parameters
  • in_channels (int) – Channel number of inputs. Default: 3.

  • in_size (int) – Size of input image. Default: 160.

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor

class mmedit.models.editors.TTSRNet(in_channels, out_channels, mid_channels=64, texture_channels=64, num_blocks=(16, 16, 8, 4), res_scale=1.0)

Bases: mmengine.model.BaseModule

TTSR network structure (main-net) for reference-based super-resolution.

Paper: Learning Texture Transformer Network for Image Super-Resolution

Adapted from https://github.com/researchmm/TTSR. Copyright permission at https://github.com/researchmm/TTSR/issues/38.

Parameters
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels in the output image

  • mid_channels (int) – Channel number of intermediate features. Default: 64

  • texture_channels (int) – Number of texture channels. Default: 64.

  • num_blocks (tuple[int]) – Block numbers in the trunk network. Default: (16, 16, 8, 4)

  • res_scale (float) – Used to scale the residual in residual block. Default: 1.

forward(x, soft_attention, textures)

Forward function.

Parameters
  • x (Tensor) – Input tensor with shape (n, c, h, w).

  • soft_attention (Tensor) – Soft-Attention tensor with shape (n, 1, h, w).

  • textures (Tuple[Tensor]) – Transferred HR texture tensors. [(N, C, H, W), (N, C/2, 2H, 2W), …]

Returns

Forward results.

Return type

Tensor

class mmedit.models.editors.WGANGP(*args, **kwargs)

Bases: mmedit.models.base_models.BaseGAN

Implementation of Improved Training of Wasserstein GANs.

Paper link: https://arxiv.org/pdf/1704.00028

Detailed architectures can be found in mmgen.models.architectures.wgan_gp.generator_discriminator.WGANGPGenerator and mmgen.models.architectures.wgan_gp.generator_discriminator.WGANGPDiscriminator.

disc_loss(real_data: torch.Tensor, fake_data: torch.Tensor, disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor) Tuple

Get disc loss. WGAN-GP uses the WGAN loss and a gradient penalty to train the discriminator.

Parameters
  • real_data (Tensor) – Real input data.

  • fake_data (Tensor) – Fake input data.

  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]

gen_loss(disc_pred_fake: torch.Tensor) Tuple

Get gen loss. WGAN-GP uses the WGAN loss to train the generator.

Parameters

disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]
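The two losses above can be sketched numerically (NumPy, with the interpolate-and-autograd step replaced by precomputed gradient norms; the penalty weight of 10 is the WGAN-GP paper's default, assumed here rather than read from mmedit's config):

```python
import numpy as np

GP_WEIGHT = 10.0  # lambda from the WGAN-GP paper (assumed default)

def wgangp_disc_loss(disc_pred_real, disc_pred_fake, grad_norms):
    """Critic loss: Wasserstein estimate plus gradient penalty.

    grad_norms are ||grad D(x_hat)|| at interpolates between real and fake
    samples; in the real model they come from autograd, not as inputs.
    """
    wasserstein = disc_pred_fake.mean() - disc_pred_real.mean()
    gradient_penalty = GP_WEIGHT * ((grad_norms - 1.0) ** 2).mean()
    return wasserstein + gradient_penalty

def wgangp_gen_loss(disc_pred_fake):
    """Generator loss: maximize the critic's score on fake samples."""
    return -disc_pred_fake.mean()

real = np.array([1.0, 2.0])
fake = np.array([0.0, 1.0])
norms = np.array([1.0, 1.0])  # unit gradient norms -> zero penalty
d_loss = wgangp_disc_loss(real, fake, norms)  # (0.5 - 1.5) + 0 = -1.0
g_loss = wgangp_gen_loss(fake)                # -0.5
```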

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train discriminator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Train generator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Not used in the generator's training.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, Tensor]
