mmedit.models.editors.ttsr

Package Contents

Classes

LTE

Learnable Texture Extractor.

SearchTransformer

Search texture reference by transformer.

TTSR

TTSR model for Reference-based Image Super-Resolution.

TTSRDiscriminator

A discriminator for TTSR.

TTSRNet

TTSR network structure (main-net) for reference-based super-resolution.

class mmedit.models.editors.ttsr.LTE(requires_grad=True, pixel_range=1.0, load_pretrained_vgg=True)[source]

Bases: mmengine.model.BaseModule

Learnable Texture Extractor.

Based on pretrained VGG19. Generate features in 3 levels.

Parameters
  • requires_grad (bool) – Require grad or not. Default: True.

  • pixel_range (float) – Pixel range of feature. Default: 1.

  • load_pretrained_vgg (bool) – Whether to load pretrained VGG weights from torchvision. Default: True. For training, the pretrained VGG must be loaded. For evaluation it is not needed, because the pretrained LTE weights are loaded instead.

forward(x)[source]

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, 3, h, w).

Returns

Forward results in 3 levels.

  • x_level3: Forward results in level 3 (n, 256, h/4, w/4).

  • x_level2: Forward results in level 2 (n, 128, h/2, w/2).

  • x_level1: Forward results in level 1 (n, 64, h, w).

Return type

Tuple[Tensor]
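
The three output shapes listed above can be sketched with a small helper. This is only an illustration of the shape arithmetic, not part of the library: the name lte_output_shapes is hypothetical, and the sketch assumes h and w are divisible by 4.

```python
def lte_output_shapes(n, h, w):
    """Return the (level3, level2, level1) shapes for an (n, 3, h, w) input.

    Hypothetical helper: it only mirrors the shapes listed in the
    documentation and does not run any network.
    """
    if h % 4 or w % 4:
        raise ValueError("this sketch assumes h and w are divisible by 4")
    level3 = (n, 256, h // 4, w // 4)  # deepest features, smallest spatial size
    level2 = (n, 128, h // 2, w // 2)
    level1 = (n, 64, h, w)             # shallowest features, full spatial size
    return level3, level2, level1


shapes = lte_output_shapes(2, 160, 160)
```

The order of the returned tuple follows the order in which the levels are listed under Returns.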

class mmedit.models.editors.ttsr.SearchTransformer[source]

Bases: torch.nn.Module

Search texture reference by transformer.

Include relevance embedding, hard-attention and soft-attention.

gather(inputs, dim, index)[source]

Hard Attention. Gathers values along an axis specified by dim.

Parameters
  • inputs (Tensor) – The source tensor. (N, C*k*k, H*W)

  • dim (int) – The axis along which to index.

  • index (Tensor) – The indices of elements to gather. (N, H*W)

Returns

The result tensor. (N, C*k*k, H*W)

Return type

Tensor
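
The gathering step can be illustrated without PyTorch. The sketch below is a plain-Python stand-in that assumes dim is the last axis and that the (N, H*W) index is shared across the channel axis, which is what the listed shapes suggest. The name gather_last_dim is hypothetical and this is not the library implementation.

```python
def gather_last_dim(inputs, index):
    """Pure-Python stand-in for gathering along the last axis.

    inputs: nested list with shape (N, C, P)
    index:  nested list with shape (N, P), each value a position in inputs
    returns a nested list with shape (N, C, P) where
    outputs[n][c][p] == inputs[n][c][index[n][p]]
    """
    return [
        [[channel[i] for i in sample_index] for channel in sample]
        for sample, sample_index in zip(inputs, index)
    ]


inputs = [[[10, 11, 12], [20, 21, 22]]]  # shape (1, 2, 3)
index = [[2, 0, 0]]                      # shape (1, 3)
outputs = gather_last_dim(inputs, index)
```

Every channel of a sample is read at the same positions, so each query position receives a complete feature vector from one reference position.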

forward(img_lq, ref_lq, refs)[source]

Texture transformer.

Q = LTE(img_lq)

K = LTE(ref_lq)

V = LTE(ref), from V_level_n to V_level_1

Relevance embedding aims to embed the relevance between the LQ and Ref image by estimating the similarity between Q and K.

Hard-Attention: Only transfer features from the most relevant position in V for each query.

Soft-Attention: synthesize features from the transferred GT texture features T and the LQ features F from the backbone.

Parameters
  • All arguments are features produced by the extractor. These features contain 3 levels. When upscale_factor=4, the size ratio of these features is level3:level2:level1 = 1:2:4.

  • img_lq (Tensor) – Tensor of 4x bicubic-upsampled lq image. (N, C, H, W)

  • ref_lq (Tensor) – Tensor of ref_lq. ref_lq is obtained by applying bicubic down-sampling and up-sampling with factor 4x on ref. (N, C, H, W)

  • refs (Tuple[Tensor]) – Tuple of ref tensors. [(N, C, H, W), (N, C/2, 2H, 2W), …]

Returns

tuple contains:

soft_attention (Tensor): Soft-Attention tensor. (N, 1, H, W)

textures (Tuple[Tensor]): Transferred GT textures. [(N, C, H, W), (N, C/2, 2H, 2W), …]

Return type

tuple
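
The relation between relevance embedding, hard attention and soft attention can be sketched in plain Python. The sketch assumes the similarity between Q and K is a normalized inner product and that features are already flattened into one vector per position. The function names are hypothetical and this is not the module's implementation.

```python
import math


def normalize(vector):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector] if norm > 0 else list(vector)


def relevance(queries, keys):
    """Relevance embedding: similarity of every query with every key."""
    unit_q = [normalize(q) for q in queries]
    unit_k = [normalize(k) for k in keys]
    return [[sum(a * b for a, b in zip(q, k)) for k in unit_k] for q in unit_q]


def hard_soft_attention(queries, keys):
    """Return (hard_index, soft_attention) for each query position.

    hard_index[i] is the most relevant key position for query i,
    soft_attention[i] is the relevance score at that position.
    """
    rel = relevance(queries, keys)
    hard_index = [max(range(len(row)), key=row.__getitem__) for row in rel]
    soft_attention = [row[i] for row, i in zip(rel, hard_index)]
    return hard_index, soft_attention


queries = [[1.0, 0.0], [0.0, 2.0]]           # two query positions
keys = [[0.0, 1.0], [3.0, 0.0], [1.0, 1.0]]  # three key positions
hard_index, soft_attention = hard_soft_attention(queries, keys)
```

In this sketch, hard_index plays the role of the index argument of gather, and soft_attention corresponds to the confidence map returned alongside the textures.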

class mmedit.models.editors.ttsr.TTSR(generator, extractor, transformer, pixel_loss, discriminator=None, perceptual_loss=None, transferal_perceptual_loss=None, gan_loss=None, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)[source]

Bases: mmedit.models.editors.srgan.SRGAN

TTSR model for Reference-based Image Super-Resolution.

Paper: Learning Texture Transformer Network for Image Super-Resolution.

Parameters
  • generator (dict) – Config for the generator.

  • extractor (dict) – Config for the extractor.

  • transformer (dict) – Config for the transformer.

  • pixel_loss (dict) – Config for the pixel loss.

  • discriminator (dict) – Config for the discriminator. Default: None.

  • perceptual_loss (dict) – Config for the perceptual loss. Default: None.

  • transferal_perceptual_loss (dict) – Config for the transferal perceptual loss. Default: None.

  • gan_loss (dict) – Config for the GAN loss. Default: None.

  • train_cfg (dict) – Config for train. Default: None.

  • test_cfg (dict) – Config for testing. Default: None.

  • init_cfg (dict, optional) – The weight initialized config for BaseModule. Default: None.

  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor. Default: None.
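
As a rough illustration of how these parameters fit together, a config might look like the sketch below. It assumes the usual registry convention of a type key naming each class. The pixel loss type and weight and the in_channels and out_channels values are assumptions, not values taken from this page. The remaining values are the defaults documented for each class.

```python
# Hypothetical config sketch, not a shipped configuration file.
model = dict(
    type='TTSR',
    generator=dict(
        type='TTSRNet',
        in_channels=3,   # assumed: RGB input
        out_channels=3,  # assumed: RGB output
        mid_channels=64,
        texture_channels=64,
        num_blocks=(16, 16, 8, 4),
        res_scale=1.0),
    extractor=dict(
        type='LTE',
        requires_grad=True,
        pixel_range=1.0,
        load_pretrained_vgg=True),
    transformer=dict(type='SearchTransformer'),
    pixel_loss=dict(type='L1Loss', loss_weight=1.0),  # assumed loss type and weight
    # discriminator, perceptual_loss, transferal_perceptual_loss and gan_loss
    # are left at their default of None in this sketch.
)
```

Only the four arguments without defaults (generator, extractor, transformer, pixel_loss) are set here.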

forward_tensor(inputs, data_samples=None, training=False)[source]

Forward tensor. Returns result of simple forward.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

  • training (bool) – Whether the model is in training mode. Default: False.

Returns

Results of forward inference and forward train.

Return type

(Tensor | Tuple[List[Tensor]])

if_run_g()[source]

Determines whether the generator step needs to run.

if_run_d()[source]

Determines whether the discriminator step needs to run.

g_step(batch_outputs, batch_gt_data: mmedit.structures.EditDataSample)[source]

G step of GAN: Calculate losses of generator.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

Returns

Dict of losses.

Return type

dict

g_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)[source]

G step with optim of GAN: Calculate losses of generator and run optim.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

  • optim_wrapper (OptimWrapperDict) – Optim wrapper dict.

Returns

Dict of parsed losses.

Return type

dict

d_step_with_optim(batch_outputs, batch_gt_data, optim_wrapper)[source]

D step with optim of GAN: Calculate losses of discriminator and run optim.

Parameters
  • batch_outputs (Tuple[Tensor]) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

  • optim_wrapper (OptimWrapper) – Optim wrapper of discriminator.

Returns

Dict of parsed losses.

Return type

dict

class mmedit.models.editors.ttsr.TTSRDiscriminator(in_channels=3, in_size=160)[source]

Bases: mmengine.model.BaseModule

A discriminator for TTSR.

Parameters
  • in_channels (int) – Channel number of inputs. Default: 3.

  • in_size (int) – Size of input image. Default: 160.

forward(x)[source]

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor

class mmedit.models.editors.ttsr.TTSRNet(in_channels, out_channels, mid_channels=64, texture_channels=64, num_blocks=(16, 16, 8, 4), res_scale=1.0)[source]

Bases: mmengine.model.BaseModule

TTSR network structure (main-net) for reference-based super-resolution.

Paper: Learning Texture Transformer Network for Image Super-Resolution

Adapted from ‘https://github.com/researchmm/TTSR’. Copyright permission at ‘https://github.com/researchmm/TTSR/issues/38’.

Parameters
  • in_channels (int) – Number of channels in the input image.

  • out_channels (int) – Number of channels in the output image.

  • mid_channels (int) – Channel number of intermediate features. Default: 64.

  • texture_channels (int) – Number of texture channels. Default: 64.

  • num_blocks (tuple[int]) – Block numbers in the trunk network. Default: (16, 16, 8, 4).

  • res_scale (float) – Used to scale the residual in residual block. Default: 1.

forward(x, soft_attention, textures)[source]

Forward function.

Parameters
  • x (Tensor) – Input tensor with shape (n, c, h, w).

  • soft_attention (Tensor) – Soft-Attention tensor with shape (n, 1, h, w).

  • textures (Tuple[Tensor]) – Transferred HR texture tensors. [(N, C, H, W), (N, C/2, 2H, 2W), …]

Returns

Forward results.

Return type

Tensor
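
The shape pattern given for textures, [(N, C, H, W), (N, C/2, 2H, 2W), …], can be expanded with a short helper. The name texture_shapes is hypothetical, and the default of three levels is an assumption based on the three feature levels described for the extractor.

```python
def texture_shapes(n, c, h, w, levels=3):
    """List the expected shape of each texture tensor, coarsest level first.

    Each following level halves the channels and doubles the height and width.
    """
    shapes = []
    for level in range(levels):
        scale = 2 ** level
        if c % scale:
            raise ValueError("c must be divisible by 2 ** (levels - 1)")
        shapes.append((n, c // scale, h * scale, w * scale))
    return shapes


expected = texture_shapes(1, 256, 40, 40)
```

A check like this can help confirm that the textures passed to forward follow the pattern before debugging the network itself.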
