
mmedit.models.editors.srgan

Package Contents

Classes

ModifiedVGG

A modified VGG discriminator with input size 128 x 128.

MSRResNet

Modified SRResNet.

SRGAN

SRGAN model for single image super-resolution.

class mmedit.models.editors.srgan.ModifiedVGG(in_channels, mid_channels)

Bases: mmengine.model.BaseModule

A modified VGG discriminator with input size 128 x 128.

It is used to train SRGAN and ESRGAN.

Parameters
  • in_channels (int) – Channel number of inputs (typically 3 for RGB images).

  • mid_channels (int) – Channel number of base intermediate features (typically 64).
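The fixed 128 x 128 input size follows from the discriminator's downsampling: the flattened feature map has to match the size the first linear layer expects. The plain-Python sketch below traces the spatial size. It assumes, without quoting the implementation, five stages that each pair a 3x3 stride-1 conv with a 4x4 stride-2 conv (padding 1), and 8 x mid_channels channels in the last stage; the function names are hypothetical.

```python
def conv_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a 2D convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def vgg128_feature_size(input_size: int = 128, num_stages: int = 5) -> int:
    """Trace the spatial size through the assumed downsampling stages.

    Assumption: each stage is a 3x3/stride-1 conv (keeps the size)
    followed by a 4x4/stride-2 conv with padding 1 (halves the size).
    """
    size = input_size
    for _ in range(num_stages):
        size = conv_out_size(size, kernel=3, stride=1, padding=1)  # keeps size
        size = conv_out_size(size, kernel=4, stride=2, padding=1)  # halves size
    return size


def flattened_features(mid_channels: int = 64, input_size: int = 128) -> int:
    """Length of the flattened vector fed to the first linear layer.

    Assumption: the last stage has mid_channels * 8 channels.
    """
    spatial = vgg128_feature_size(input_size)
    return mid_channels * 8 * spatial * spatial
```

Under these assumptions a 128 x 128 input ends at 4 x 4, while a 96 x 96 input would end at 3 x 3, so its flattened features would no longer match a linear layer sized for 128 x 128 inputs.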

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, c, h, w), where h = w = 128.

Returns

Forward results.

Return type

Tensor

class mmedit.models.editors.srgan.MSRResNet(in_channels, out_channels, mid_channels=64, num_blocks=16, upscale_factor=4)

Bases: mmengine.model.BaseModule

Modified SRResNet.

A compact version modified from SRResNet in “Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network”.

It uses residual blocks without BN, similar to EDSR. Currently, it supports x2, x3 and x4 upsampling scale factors.

Parameters
  • in_channels (int) – Channel number of inputs.

  • out_channels (int) – Channel number of outputs.

  • mid_channels (int) – Channel number of intermediate features. Default: 64.

  • num_blocks (int) – Block number in the trunk network. Default: 16.

  • upscale_factor (int) – Upsampling factor. Supports x2, x3 and x4. Default: 4.
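The restriction to x2, x3 and x4 and the resulting output shape can be sketched in plain Python. The helper names are hypothetical, and the stage layout (one x2 or x3 pixel-shuffle stage, or two x2 stages for x4) is an assumption about how such an upsampling head is commonly built, not a quote of the implementation.

```python
from typing import List, Tuple

# Mirrors the documented _supported_upscale_factors attribute.
SUPPORTED_UPSCALE_FACTORS = (2, 3, 4)


def upsample_stages(upscale_factor: int) -> List[int]:
    """Per-stage scale factors of the (assumed) upsampling head."""
    if upscale_factor not in SUPPORTED_UPSCALE_FACTORS:
        raise ValueError(
            f'Unsupported scale factor {upscale_factor}. '
            f'Supported factors: {list(SUPPORTED_UPSCALE_FACTORS)}.')
    if upscale_factor == 4:
        return [2, 2]  # assumption: x4 is built from two x2 stages
    return [upscale_factor]


def output_shape(input_shape: Tuple[int, int, int, int], out_channels: int,
                 upscale_factor: int) -> Tuple[int, int, int, int]:
    """Shape of the super-resolved output for an (n, c, h, w) input."""
    n, _, h, w = input_shape
    scale = 1
    for stage in upsample_stages(upscale_factor):
        scale *= stage
    return (n, out_channels, h * scale, w * scale)
```

For example, a 32 x 32 low-resolution patch at x4 yields a 128 x 128 output, which is the input size ModifiedVGG expects.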

_supported_upscale_factors = [2, 3, 4]

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor

init_weights()

Initialize the weights of the model.

class mmedit.models.editors.srgan.SRGAN(generator, discriminator=None, gan_loss=None, pixel_loss=None, perceptual_loss=None, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)

Bases: mmedit.models.base_models.BaseEditModel

SRGAN model for single image super-resolution.

Ref: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network.

Parameters
  • generator (dict) – Config for the generator.

  • discriminator (dict) – Config for the discriminator. Default: None.

  • gan_loss (dict) – Config for the GAN loss. Note that the loss weight in the GAN loss applies only to the generator. Default: None.

  • pixel_loss (dict) – Config for the pixel loss. Default: None.

  • perceptual_loss (dict) – Config for the perceptual loss. Default: None.

  • train_cfg (dict) – Config for training. Default: None.

  • test_cfg (dict) – Config for testing. Default: None.

  • init_cfg (dict, optional) – The weight initialization config for BaseModule. Default: None.

  • data_preprocessor (dict, optional) – The preprocessing config for BaseDataPreprocessor. Default: None.
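The constructor arguments above are usually supplied through a config dict. The sketch below pairs MSRResNet with ModifiedVGG; the registry type names for the losses and all loss arguments and weights are illustrative assumptions, so check them against the configs shipped with the library before use.

```python
scale = 4  # illustrative upscale factor

model = dict(
    type='SRGAN',
    generator=dict(
        type='MSRResNet',
        in_channels=3,
        out_channels=3,
        mid_channels=64,
        num_blocks=16,
        upscale_factor=scale),
    discriminator=dict(type='ModifiedVGG', in_channels=3, mid_channels=64),
    # Loss types and values below are assumptions for illustration only.
    pixel_loss=dict(type='L1Loss', loss_weight=1e-2, reduction='mean'),
    perceptual_loss=dict(
        type='PerceptualLoss',
        layer_weights={'34': 1.0},
        vgg_type='vgg19',
        perceptual_weight=1.0,
        style_weight=0,
        norm_img=False),
    gan_loss=dict(
        type='GANLoss',
        gan_type='vanilla',
        loss_weight=5e-3,  # per the docs, weights only the generator's GAN loss
        real_label_val=1.0,
        fake_label_val=0),
    train_cfg=dict(),
    test_cfg=dict(),
)
```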

forward_train(inputs, data_samples=None, **kwargs)

Forward training. Training losses are calculated in train_step.

Parameters
  • inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.

Returns

Result of forward_tensor with training=True.

Return type

Tensor

forward_tensor(inputs, data_samples=None, training=False)

Forward tensor. Returns the result of a simple forward pass.

Parameters
  • inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.

  • training (bool) – Whether the model is in training mode. Default: False.

Returns

Result of a simple forward pass.

Return type

Tensor

if_run_g()

Calculates whether the generator step needs to be run.

if_run_d()

Calculates whether the discriminator step needs to be run.
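A common way to schedule the two steps is to update the generator only every few iterations and only after a discriminator warm-up, and to run the discriminator step only when both a discriminator and a GAN loss are configured. The sketch below assumes that logic; the train_cfg keys disc_steps and disc_init_steps and the StepSchedule class are hypothetical stand-ins, not the documented API.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StepSchedule:
    """Hypothetical stand-in for the scheduling state of a GAN model."""
    disc_steps: int = 1        # assumed: run G once every `disc_steps` iterations
    disc_init_steps: int = 0   # assumed: iterations in which only D is trained
    step_counter: int = 0
    discriminator: Optional[object] = None
    gan_loss: Optional[object] = None

    def if_run_g(self) -> bool:
        """Run G on every `disc_steps`-th iteration after the warm-up."""
        return (self.step_counter % self.disc_steps == 0
                and self.step_counter >= self.disc_init_steps)

    def if_run_d(self) -> bool:
        """Run D only if both a discriminator and a GAN loss exist."""
        return self.discriminator is not None and self.gan_loss is not None
```

With disc_steps=2 and disc_init_steps=3, the generator would be updated at iterations 4 and 6 out of the first eight.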

g_step(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor)

Generator step of GAN: calculate the generator losses.

Parameters
  • batch_outputs (Tensor) – Batch output of the generator.

  • batch_gt_data (Tensor) – Batch GT data.

Returns

Dict of losses.

Return type

dict
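The generator loss dict combines whichever terms are configured (pixel, perceptual and adversarial). The sketch below collects already-weighted scalar terms and totals every entry whose key contains 'loss', which is how mmengine's parse_losses sums a loss dict. The key names are hypothetical placeholders, and plain floats stand in for tensors.

```python
from typing import Dict, Optional


def generator_losses(pixel: Optional[float] = None,
                     perceptual: Optional[float] = None,
                     style: Optional[float] = None,
                     adversarial: Optional[float] = None) -> Dict[str, float]:
    """Collect the configured generator loss terms (None means disabled).

    Each value is assumed to be already multiplied by its loss weight.
    Key names are hypothetical.
    """
    losses: Dict[str, float] = {}
    if pixel is not None:
        losses['loss_pix'] = pixel
    if perceptual is not None:
        losses['loss_perceptual'] = perceptual
    if style is not None:
        losses['loss_style'] = style
    if adversarial is not None:
        losses['loss_gan'] = adversarial
    return losses


def total_loss(losses: Dict[str, float]) -> float:
    """Sum every entry whose key contains 'loss'."""
    return sum(value for key, value in losses.items() if 'loss' in key)
```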

d_step_real(batch_outputs, batch_gt_data: torch.Tensor)

Real part of the discriminator step.

Parameters
  • batch_outputs (Tensor) – Batch output of the generator.

  • batch_gt_data (Tensor) – Batch GT data.

Returns

Real part of gan_loss for the discriminator.

Return type

Tensor

d_step_fake(batch_outputs: torch.Tensor, batch_gt_data)

Fake part of the discriminator step.

Parameters
  • batch_outputs (Tensor) – Batch output of the generator.

  • batch_gt_data (Tensor) – Batch GT data.

Returns

Fake part of gan_loss for the discriminator.

Return type

Tensor
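For a vanilla GAN loss, the real and fake parts are binary cross-entropy terms on the discriminator's logits: ground-truth images are pushed towards the real label and generator outputs towards the fake label. The sketch below works on single scalar logits and assumes a vanilla (BCE) GAN loss; consistent with the note on gan_loss above, no loss weight is applied to these discriminator terms.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy of sigmoid(logit) against a target label."""
    # -[t*log(sigmoid(z)) + (1-t)*log(1-sigmoid(z))]
    #   = t*softplus(-z) + (1-t)*softplus(z)
    return target * softplus(-logit) + (1.0 - target) * softplus(logit)


def d_loss_real(real_logit: float, real_label_val: float = 1.0) -> float:
    """Real part: D should classify ground-truth images as real."""
    return bce_with_logits(real_logit, real_label_val)


def d_loss_fake(fake_logit: float, fake_label_val: float = 0.0) -> float:
    """Fake part: D should classify (detached) generator outputs as fake."""
    return bce_with_logits(fake_logit, fake_label_val)
```

An undecided discriminator (logit 0) pays log 2 on each part, and a confident, correct prediction drives the corresponding term towards zero.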

g_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)

Generator step with optimization: calculate the generator losses and run the optimizer.

Parameters
  • batch_outputs (Tensor) – Batch output of the generator.

  • batch_gt_data (Tensor) – Batch GT data.

  • optim_wrapper (OptimWrapperDict) – Optim wrapper dict.

Returns

Dict of parsed losses.

Return type

dict

d_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)

Discriminator step with optimization: calculate the discriminator losses and run the optimizer.

Parameters
  • batch_outputs (Tensor) – Batch output of the generator.

  • batch_gt_data (Tensor) – Batch GT data.

  • optim_wrapper (OptimWrapperDict) – Optim wrapper dict.

Returns

Dict of parsed losses.

Return type

dict
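Both *_with_optim methods follow the same pattern: total the loss dict, hand the total to the optim wrapper of the sub-model being trained, and return log variables. The sketch below assumes the OptimWrapperDict is keyed by 'generator' and 'discriminator'; RecordingOptimWrapper is a hypothetical stub that only mimics mmengine's OptimWrapper.update_params(loss), which in the real class runs backward, step and zero_grad.

```python
from typing import Dict, List


class RecordingOptimWrapper:
    """Hypothetical stub for an optim wrapper; records received losses."""

    def __init__(self) -> None:
        self.updates: List[float] = []

    def update_params(self, loss: float) -> None:
        # A real wrapper would backpropagate and step the optimizer here.
        self.updates.append(loss)


def step_with_optim(losses: Dict[str, float],
                    optim_wrappers: Dict[str, RecordingOptimWrapper],
                    name: str) -> Dict[str, float]:
    """Total the losses, update sub-model `name`, return log variables."""
    total = sum(value for key, value in losses.items() if 'loss' in key)
    optim_wrappers[name].update_params(total)
    log_vars = dict(losses)
    log_vars['loss'] = total
    return log_vars
```

Only the wrapper named in the call is updated, so the generator and discriminator parameters are optimized separately.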

extract_gt_data(data_samples)

Extract GT data from data samples.

Parameters

data_samples (list) – List of EditDataSample.

Returns

Extracted GT data.

Return type

Tensor
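Extracting GT data amounts to gathering each sample's ground-truth image into one batch, which only works when all images share a shape (as with stacking tensors). The plain-Python sketch below assumes each sample exposes its GT image as a `gt_img` attribute holding a nested list; the attribute layout and helper names are hypothetical.

```python
from types import SimpleNamespace
from typing import Any, List, Sequence, Tuple


def shape_of(nested: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)


def extract_gt_data(data_samples: Sequence[Any]) -> List[Any]:
    """Gather the (assumed) `gt_img` of every sample into one batch.

    Stacking requires equal shapes, so mismatched samples are rejected.
    """
    batch = [sample.gt_img for sample in data_samples]
    shapes = {shape_of(img) for img in batch}
    if len(shapes) > 1:
        raise ValueError(f'GT images have different shapes: {sorted(shapes)}')
    return batch


# Usage with two hypothetical 2x2 single-channel samples.
samples = [SimpleNamespace(gt_img=[[0, 1], [2, 3]]),
           SimpleNamespace(gt_img=[[4, 5], [6, 7]])]
batch = extract_gt_data(samples)
print(shape_of(batch))  # (2, 2, 2)
```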

train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) → Dict[str, torch.Tensor]

Train step of GAN-based method.

Parameters
  • data (List[dict]) – Data sampled from the dataloader.

  • optim_wrapper (OptimWrapperDict) – Optim wrapper dict used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, torch.Tensor]
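The methods listed above fit together in one training iteration. The framework-free sketch below assumes a typical ordering: one generator forward pass, then a generator update when if_run_g() allows it, then a discriminator update on detached generator outputs when if_run_d() allows it, and finally an increment of the step counter. The class, the placeholder loss values and the train_cfg-style attributes are hypothetical; each stage only records its name.

```python
from typing import Dict, List


class TinyGANTrainLoop:
    """Hypothetical outline of a GAN-based train_step; stages are recorded."""

    def __init__(self, disc_steps: int = 1, disc_init_steps: int = 0,
                 use_gan: bool = True) -> None:
        self.disc_steps = disc_steps
        self.disc_init_steps = disc_init_steps
        self.use_gan = use_gan
        self.step_counter = 0
        self.calls: List[str] = []

    def if_run_g(self) -> bool:
        return (self.step_counter % self.disc_steps == 0
                and self.step_counter >= self.disc_init_steps)

    def if_run_d(self) -> bool:
        return self.use_gan

    def train_step(self) -> Dict[str, float]:
        log_vars: Dict[str, float] = {}
        # 1. Generator forward pass shared by both steps.
        self.calls.append('forward_train')
        # 2. Generator update (discriminator weights would be frozen here).
        if self.if_run_g():
            self.calls.append('g_step_with_optim')
            log_vars['loss_g'] = 1.0  # placeholder value
        # 3. Discriminator update on detached generator outputs.
        if self.if_run_d():
            self.calls.append('d_step_with_optim')
            log_vars['loss_d'] = 0.5  # placeholder value
        # 4. Advance the counter that drives if_run_g().
        self.step_counter += 1
        return log_vars
```

With disc_init_steps=1, the first iteration trains only the discriminator and the second trains both.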
