
mmedit.models.editors.real_esrgan

Package Contents

Classes

RealESRGAN

Real-ESRGAN model for single image super-resolution.

UNetDiscriminatorWithSpectralNorm

A U-Net discriminator with spectral normalization.

class mmedit.models.editors.real_esrgan.RealESRGAN(generator, discriminator=None, gan_loss=None, pixel_loss=None, perceptual_loss=None, is_use_sharpened_gt_in_pixel=False, is_use_sharpened_gt_in_percep=False, is_use_sharpened_gt_in_gan=False, is_use_ema=True, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)

Bases: mmedit.models.editors.srgan.SRGAN

Real-ESRGAN model for single image super-resolution.

Ref: Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data, 2021.

Note: generator_ema is updated by EMA_HOOK rather than by this class.

Parameters
  • generator (dict) – Config for the generator.

  • discriminator (dict, optional) – Config for the discriminator. Default: None.

  • gan_loss (dict, optional) – Config for the GAN loss. Note that the loss weight in the GAN loss applies only to the generator. Default: None.

  • pixel_loss (dict, optional) – Config for the pixel loss. Default: None.

  • perceptual_loss (dict, optional) – Config for the perceptual loss. Default: None.

  • is_use_sharpened_gt_in_pixel (bool, optional) – Whether to use the image sharpened by unsharp masking as the GT for pixel loss. Default: False.

  • is_use_sharpened_gt_in_percep (bool, optional) – Whether to use the image sharpened by unsharp masking as the GT for perceptual loss. Default: False.

  • is_use_sharpened_gt_in_gan (bool, optional) – Whether to use the image sharpened by unsharp masking as the GT for adversarial loss. Default: False.

  • is_use_ema (bool, optional) – Whether to apply exponential moving average on the network weights. Default: True.

  • train_cfg (dict, optional) – Config for training. Default: None. You can adjust GAN training with two keys: disc_steps, the number of discriminator updates per generator update; and disc_init_steps, the number of discriminator updates at the start of training before the generator is updated. These two keys are useful when training with WGAN.

  • test_cfg (dict, optional) – Config for testing. Default: None.

  • init_cfg (dict, optional) – Weight initialization config for BaseModule. Default: None.

  • data_preprocessor (dict, optional) – Preprocessing config for BaseDataPreprocessor. Default: None.
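The parameters above are normally supplied through an mmedit config. The fragment below is a hypothetical sketch: the registry names (RRDBNet, GANLoss, L1Loss, PerceptualLoss) and every numeric value or keyword inside the sub-configs are assumptions to check against the configs shipped with your mmedit version. The helper should_update_generator is also hypothetical; it only restates the disc_steps / disc_init_steps description of train_cfg, and the parent SRGAN training loop may implement the schedule differently.

```python
# Hypothetical config fragment for RealESRGAN. Registry names and all
# sub-config keywords/values are illustrative assumptions, not verified defaults.
model = dict(
    type='RealESRGAN',
    generator=dict(
        type='RRDBNet', in_channels=3, out_channels=3, mid_channels=64,
        num_blocks=23, growth_channels=32, upscale_factor=4),
    discriminator=dict(
        type='UNetDiscriminatorWithSpectralNorm', in_channels=3,
        mid_channels=64, skip_connection=True),
    pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
    perceptual_loss=dict(
        type='PerceptualLoss', layer_weights={'34': 1.0}, vgg_type='vgg19',
        perceptual_weight=1.0, style_weight=0, norm_img=False),
    # This loss_weight only scales the generator's adversarial term.
    gan_loss=dict(type='GANLoss', gan_type='vanilla', loss_weight=1e-1,
                  real_label_val=1.0, fake_label_val=0),
    is_use_sharpened_gt_in_pixel=True,
    is_use_sharpened_gt_in_percep=True,
    is_use_sharpened_gt_in_gan=False,
    is_use_ema=True,
    train_cfg=dict(disc_steps=1, disc_init_steps=0),
)


def should_update_generator(step: int, disc_steps: int = 1,
                            disc_init_steps: int = 0) -> bool:
    """Hypothetical schedule: D updates every step; G updates only every
    `disc_steps` steps and only once `disc_init_steps` steps have passed."""
    return step % disc_steps == 0 and step >= disc_init_steps
```

With disc_steps=2, for example, the sketch updates the generator on every second iteration only, which is the kind of imbalance WGAN-style training often uses.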

forward_tensor(inputs, data_samples=None, training=False)

Forward the input tensor and return the result of a simple forward pass.

Parameters
  • inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.

  • training (bool) – Whether the model is in training mode. Default: False.

Returns

Result of a simple forward pass.

Return type

Tensor
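Given the is_use_ema flag and the note that generator_ema is maintained by EMA_HOOK, a plausible reading is that forward_tensor picks between the trained generator and its EMA copy. The class below is a hypothetical stand-in illustrating that selection under this assumption (training, or EMA disabled, uses generator; otherwise generator_ema); it is not the library's implementation, and ForwardTensorSketch is an invented name.

```python
import copy

import torch
from torch import nn


class ForwardTensorSketch(nn.Module):
    """Hypothetical stand-in for the generator / EMA selection in forward_tensor."""

    def __init__(self, generator: nn.Module, is_use_ema: bool = True):
        super().__init__()
        self.generator = generator
        self.is_use_ema = is_use_ema
        # Assumption: the EMA copy starts as a clone of the generator and its
        # weights are updated elsewhere (e.g. by an EMA hook), never here.
        self.generator_ema = copy.deepcopy(generator) if is_use_ema else None

    def forward_tensor(self, inputs: torch.Tensor, data_samples=None,
                       training: bool = False) -> torch.Tensor:
        # data_samples is accepted for signature parity but unused here.
        if training or not self.is_use_ema:
            return self.generator(inputs)
        return self.generator_ema(inputs)


model = ForwardTensorSketch(nn.Conv2d(3, 3, kernel_size=3, padding=1))
x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    train_out = model.forward_tensor(x, training=True)   # uses generator
    eval_out = model.forward_tensor(x, training=False)   # uses generator_ema
```

Right after construction both branches give identical outputs; they only diverge once the EMA weights are updated separately from the trained weights.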

g_step(batch_outputs, batch_gt_data)

Generator step of the GAN: calculate the generator losses.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tuple[Tensor]) – Batch GT data.

Returns

Dict of losses.

Return type

dict
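Since batch_gt_data is a tuple and the model has separate sharpening flags for the pixel, perceptual and adversarial losses, the generator losses can be pictured as each term using its own GT. The function below is a hypothetical sketch under these assumptions: the tuple order (gt_pixel, gt_percep, gt_gan), L1 as the pixel loss, a vanilla BCE adversarial loss, the loss-dict keys and the weights are all illustrative, not the library's exact behavior.

```python
import torch
import torch.nn.functional as F
from torch import nn


def g_step_sketch(batch_outputs, batch_gt_data, discriminator,
                  perceptual_fn=None, pixel_weight=1.0, gan_weight=0.1):
    """Hypothetical generator-loss composition; names and weights are assumptions."""
    gt_pixel, gt_percep, _gt_gan = batch_gt_data  # gt_gan is only needed by the D step
    losses = {}
    # Pixel loss against the (optionally sharpened) pixel GT; L1 is assumed.
    losses['loss_pix'] = pixel_weight * F.l1_loss(batch_outputs, gt_pixel)
    # Perceptual loss against the perceptual GT, if a callable is supplied.
    if perceptual_fn is not None:
        losses['loss_perceptual'] = perceptual_fn(batch_outputs, gt_percep)
    # Adversarial term: the generator wants D to score its output as real (label 1).
    fake_logits = discriminator(batch_outputs)
    losses['loss_gan'] = gan_weight * F.binary_cross_entropy_with_logits(
        fake_logits, torch.ones_like(fake_logits))
    return losses


disc = nn.Conv2d(3, 1, kernel_size=3, padding=1)
out = torch.rand(2, 3, 16, 16)
gt = out.clone()
losses = g_step_sketch(out, (gt, gt, gt), disc)
```

Here gan_weight plays the role of the GAN loss weight that, as noted for gan_loss, applies only to the generator.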

d_step_real(batch_outputs, batch_gt_data: torch.Tensor)

Real part of the discriminator (D) step.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tuple[Tensor]) – Batch GT data.

Returns

Real part of gan_loss for discriminator.

Return type

Tensor

d_step_fake(batch_outputs, batch_gt_data)

Fake part of the discriminator (D) step.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tuple[Tensor]) – Batch GT data.

Returns

Fake part of gan_loss for discriminator.

Return type

Tensor
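The two D-step halves can be sketched as follows, assuming a vanilla (BCE-with-logits) GAN loss and that the real part scores the adversarial GT (gt_gan) while the fake part scores detached generator output. The functions are hypothetical: they take fewer arguments than the documented methods, and the actual model uses whatever gan_loss is configured.

```python
import torch
import torch.nn.functional as F
from torch import nn


def d_step_real_sketch(discriminator, batch_gt_data):
    """Real part: D should score the adversarial GT as real (label 1)."""
    _gt_pixel, _gt_percep, gt_gan = batch_gt_data
    real_logits = discriminator(gt_gan)
    return F.binary_cross_entropy_with_logits(
        real_logits, torch.ones_like(real_logits))


def d_step_fake_sketch(discriminator, batch_outputs):
    """Fake part: D should score generator output as fake (label 0)."""
    # detach() keeps this loss from back-propagating into the generator.
    fake_logits = discriminator(batch_outputs.detach())
    return F.binary_cross_entropy_with_logits(
        fake_logits, torch.zeros_like(fake_logits))


disc = nn.Conv2d(3, 1, kernel_size=3, padding=1)
gt = torch.rand(2, 3, 16, 16)
fake = torch.rand(2, 3, 16, 16, requires_grad=True)
loss_d = d_step_real_sketch(disc, (gt, gt, gt)) + d_step_fake_sketch(disc, fake)
loss_d.backward()
```

After backward(), only the discriminator's parameters receive gradients; the generator output stays untouched because it was detached.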

extract_gt_data(data_samples)

Extract GT data from data samples.

Parameters

data_samples (list) – List of EditDataSample.

Returns

Extracted GT data.

Return type

Tensor
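The is_use_sharpened_gt_in_* flags suggest that GT extraction chooses, per loss, between the plain GT and an unsharp-masked copy. The sketch below shows generic unsharp masking (image plus a scaled difference between the image and its Gaussian blur) and a hypothetical select_gt helper. Both are assumptions for illustration: Real-ESRGAN's own sharpening may use different kernel settings or an extra threshold mask, and the sharpened GT may arrive precomputed in the data samples rather than being computed here.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel(size: int = 5, sigma: float = 1.0) -> torch.Tensor:
    """Normalized 2D Gaussian kernel of shape (size, size)."""
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    g = torch.exp(-coords ** 2 / (2 * sigma ** 2))
    g = g / g.sum()
    return torch.outer(g, g)


def unsharp_mask(img: torch.Tensor, amount: float = 0.5,
                 size: int = 5, sigma: float = 1.0) -> torch.Tensor:
    """Generic unsharp masking on an (n, c, h, w) image in [0, 1]."""
    c = img.shape[1]
    # One depthwise Gaussian filter per channel.
    kernel = gaussian_kernel(size, sigma).to(img)[None, None].repeat(c, 1, 1, 1)
    padded = F.pad(img, [size // 2] * 4, mode='reflect')
    blurred = F.conv2d(padded, kernel, groups=c)
    return (img + amount * (img - blurred)).clamp(0, 1)


def select_gt(gt, gt_sharp, use_in_pixel=False, use_in_percep=False,
              use_in_gan=False):
    """Hypothetical per-loss GT choice mirroring the is_use_sharpened_gt_in_* flags."""
    gt_pixel = gt_sharp if use_in_pixel else gt
    gt_percep = gt_sharp if use_in_percep else gt
    gt_gan = gt_sharp if use_in_gan else gt
    return gt_pixel, gt_percep, gt_gan
```

Flat regions are left unchanged because their blur equals the original, so the sharpening only exaggerates edges and texture.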

class mmedit.models.editors.real_esrgan.UNetDiscriminatorWithSpectralNorm(in_channels, mid_channels=64, skip_connection=True)

Bases: mmengine.model.BaseModule

A U-Net discriminator with spectral normalization.

Parameters
  • in_channels (int) – Channel number of the input.

  • mid_channels (int, optional) – Channel number of the intermediate features. Default: 64.

  • skip_connection (bool, optional) – Whether to use skip connection. Default: True.
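To make in_channels, mid_channels and skip_connection concrete, here is a reduced two-level U-Net discriminator with spectral normalization in plain PyTorch. It is a hypothetical sketch (TinyUNetDiscriminatorSN is an invented name): the real class's depth, layer layout and activations may differ. The assumption shown is the usual U-Net discriminator design, which returns a per-pixel score map at input resolution rather than a single score per image; input height and width must be divisible by 4 here.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import spectral_norm


class TinyUNetDiscriminatorSN(nn.Module):
    """Hypothetical two-level U-Net discriminator with spectral norm."""

    def __init__(self, in_channels, mid_channels=64, skip_connection=True):
        super().__init__()
        m = mid_channels
        self.skip_connection = skip_connection
        self.conv0 = nn.Conv2d(in_channels, m, 3, 1, 1)
        # Downsampling path (stride 2), spectrally normalized.
        self.conv1 = spectral_norm(nn.Conv2d(m, m * 2, 4, 2, 1, bias=False))
        self.conv2 = spectral_norm(nn.Conv2d(m * 2, m * 4, 4, 2, 1, bias=False))
        # Upsampling path, spectrally normalized.
        self.conv3 = spectral_norm(nn.Conv2d(m * 4, m * 2, 3, 1, 1, bias=False))
        self.conv4 = spectral_norm(nn.Conv2d(m * 2, m, 3, 1, 1, bias=False))
        self.out = nn.Conv2d(m, 1, 3, 1, 1)  # one score per pixel

    def forward(self, img):
        x0 = F.leaky_relu(self.conv0(img), 0.2)   # (n, m, h, w)
        x1 = F.leaky_relu(self.conv1(x0), 0.2)    # (n, 2m, h/2, w/2)
        x2 = F.leaky_relu(self.conv2(x1), 0.2)    # (n, 4m, h/4, w/4)
        x2 = F.interpolate(x2, scale_factor=2, mode='bilinear', align_corners=False)
        x3 = F.leaky_relu(self.conv3(x2), 0.2)    # (n, 2m, h/2, w/2)
        if self.skip_connection:
            x3 = x3 + x1
        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
        x4 = F.leaky_relu(self.conv4(x3), 0.2)    # (n, m, h, w)
        if self.skip_connection:
            x4 = x4 + x0
        return self.out(x4)                        # (n, 1, h, w)


disc = TinyUNetDiscriminatorSN(3, mid_channels=8)
scores = disc(torch.rand(2, 3, 32, 32))
```

Spectral normalization bounds each wrapped layer's Lipschitz constant, which is a common way to stabilize discriminator training; the skip connections let the per-pixel scores use both coarse context and fine detail.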

forward(img)

Forward function.

Parameters

img (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor
