
mmedit.models.editors.basicvsr

Package Contents

Classes

BasicVSR

BasicVSR model for video super-resolution.

BasicVSRNet

BasicVSR network structure for video super-resolution.

class mmedit.models.editors.basicvsr.BasicVSR(generator, pixel_loss, ensemble=None, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)[source]

Bases: mmedit.models.BaseEditModel

BasicVSR model for video super-resolution.

Note that this model class is also used for IconVSR.

Paper:

BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond, CVPR, 2021

Parameters
  • generator (dict) – Config for the generator structure.

  • pixel_loss (dict) – Config for the pixel-wise loss.

  • ensemble (dict, optional) – Config for ensemble. Default: None.

  • train_cfg (dict, optional) – Config for training. Default: None.

  • test_cfg (dict, optional) – Config for testing. Default: None.

  • init_cfg (dict, optional) – Weight initialization config for BaseModule. Default: None.

  • data_preprocessor (dict, optional) – Config for the data preprocessor (BaseDataPreprocessor). Default: None.
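The constructor arguments above are usually supplied as nested config dicts. The following is a hypothetical MMEditing-style config, not one taken from this page: the registry names CharbonnierLoss and EditDataPreprocessor and the loss and normalization values are assumptions. The BasicVSRNet values are the documented defaults.

```python
# Hypothetical config sketch for BasicVSR. Only type='BasicVSR', the argument
# names and the BasicVSRNet defaults come from this page; the loss and
# preprocessor registry names and their values are assumptions.
model = dict(
    type='BasicVSR',
    generator=dict(
        type='BasicVSRNet',
        mid_channels=64,          # documented default
        num_blocks=30,            # documented default
        spynet_pretrained=None),  # optionally a path to a SPyNet checkpoint
    pixel_loss=dict(type='CharbonnierLoss', loss_weight=1.0, reduction='mean'),
    ensemble=None,
    train_cfg=None,
    test_cfg=None,
    data_preprocessor=dict(
        type='EditDataPreprocessor',   # assumed registry name
        mean=[0.0, 0.0, 0.0],
        std=[255.0, 255.0, 255.0]),
)
```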

check_if_mirror_extended(lrs)

Check whether the input is a mirror-extended sequence.

If mirror-extended, the i-th (i=0, …, t-1) frame is equal to the (t-1-i)-th frame.

Parameters

lrs (Tensor) – Input LR images with shape (n, t, c, h, w).
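The mirror-extension condition above says that frame i equals frame t-1-i for every i, that is, the sequence equals itself reversed along the time axis. The sketch below checks exactly that definition. It is an illustration under assumptions: NumPy arrays stand in for the PyTorch tensors, and the function name is hypothetical, not the library method.

```python
import numpy as np


def is_mirror_extended(lrs: np.ndarray) -> bool:
    """Return True if frame i equals frame t-1-i for every i.

    lrs: array of shape (n, t, c, h, w); NumPy stands in for torch tensors.
    """
    # Reversing axis 1 (time) maps frame i to position t-1-i, so the
    # definition holds exactly when the reversed sequence is identical.
    return bool(np.array_equal(lrs, np.flip(lrs, axis=1)))


# A deterministic 3-frame clip whose first and last frames differ.
clip = np.arange(1 * 3 * 3 * 4 * 4, dtype=float).reshape(1, 3, 3, 4, 4)
# Mirror extension: append the time-reversed clip, giving 6 frames.
mirrored = np.concatenate([clip, np.flip(clip, axis=1)], axis=1)
print(is_mirror_extended(mirrored))  # True
print(is_mirror_extended(clip))      # False
```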

forward_train(inputs, data_samples=None, **kwargs)

Forward pass for training. Returns a dict of training losses.

Parameters
  • inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.

Returns

Dict of losses.

Return type

dict
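A training step of this kind runs the generator on the LR batch, compares the prediction with the ground-truth HR frames using the pixel loss, and returns the losses in a dict. The sketch below is hypothetical throughout: toy_generator (nearest-neighbour x4 upsampling) and l1_loss stand in for the configured generator and pixel_loss, and the key name 'loss' is an assumption.

```python
import numpy as np


def toy_generator(lrs: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in generator: nearest-neighbour x4 upsampling of
    an (n, t, c, h, w) array along the h and w axes."""
    return lrs.repeat(4, axis=3).repeat(4, axis=4)


def l1_loss(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean absolute error, used here in place of the configured pixel loss."""
    return float(np.mean(np.abs(pred - target)))


def forward_train_sketch(inputs: np.ndarray, gt: np.ndarray) -> dict:
    # Predict HR frames, score them against the ground truth,
    # and return the result as a dict of losses (key name assumed).
    pred = toy_generator(inputs)
    return {'loss': l1_loss(pred, gt)}


lrs = np.ones((2, 5, 3, 8, 8))
gts = np.zeros((2, 5, 3, 32, 32))
print(forward_train_sketch(lrs, gts))  # {'loss': 1.0}
```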

forward_inference(inputs, data_samples=None, **kwargs)

Forward pass for inference. Returns predictions for validation or testing.

Parameters
  • inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.

Returns

Predictions.

Return type

List[EditDataSample]
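Inference returns a list with one entry per sample rather than a single batched tensor. The sketch below shows only that splitting step. PredSample and its pred_img field are hypothetical stand-ins for EditDataSample, not its real interface.

```python
from dataclasses import dataclass
from typing import List

import numpy as np


@dataclass
class PredSample:
    """Hypothetical stand-in for EditDataSample holding one predicted clip."""
    pred_img: np.ndarray  # shape (t, c, H, W)


def forward_inference_sketch(batch_output: np.ndarray) -> List[PredSample]:
    # Iterating over axis 0 splits the (n, t, c, H, W) output
    # into one (t, c, H, W) clip per sample.
    return [PredSample(pred_img=clip) for clip in batch_output]


outputs = np.zeros((2, 5, 3, 32, 32))
preds = forward_inference_sketch(outputs)
print(len(preds), preds[0].pred_img.shape)  # 2 (5, 3, 32, 32)
```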

class mmedit.models.editors.basicvsr.BasicVSRNet(mid_channels=64, num_blocks=30, spynet_pretrained=None)[source]

Bases: mmengine.model.BaseModule

BasicVSR network structure for video super-resolution.

Supports only x4 upsampling.

Paper:

BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond, CVPR, 2021

Parameters
  • mid_channels (int) – Number of channels in the intermediate features. Default: 64.

  • num_blocks (int) – Number of residual blocks in each propagation branch. Default: 30.

  • spynet_pretrained (str, optional) – Path to the pre-trained SPyNet model. Default: None.
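The two propagation branches run in opposite temporal directions: the backward-time branch walks from the last frame to the first, and the forward-time branch walks from the first frame to the last. Each output frame therefore sees information from both its past and its future. The skeleton below shows only this ordering. The step function is a hypothetical stand-in for "warp the previous feature with optical flow, combine it with the current frame, and run num_blocks residual blocks". Features are plain floats, and both branches start from zero.

```python
from typing import Callable, List, Tuple

# Hypothetical per-step update: (previous feature, current frame) -> feature.
Step = Callable[[float, float], float]


def bidirectional_propagation(frames: List[float],
                              step: Step) -> List[Tuple[float, float]]:
    t = len(frames)
    # Backward-time branch: last frame to first, so the feature at frame i
    # aggregates frames i..t-1.
    backward = [0.0] * t
    feat = 0.0
    for i in range(t - 1, -1, -1):
        feat = step(feat, frames[i])
        backward[i] = feat
    # Forward-time branch: first frame to last, so the feature at frame i
    # aggregates frames 0..i.
    outputs = []
    feat = 0.0
    for i in range(t):
        feat = step(feat, frames[i])
        # Pair both branch features for frame i; the real network fuses
        # them before upsampling.
        outputs.append((feat, backward[i]))
    return outputs


# With a running-sum step the aggregation pattern is easy to read off.
print(bidirectional_propagation([1.0, 2.0, 3.0], lambda h, x: h + x))
# [(1.0, 6.0), (3.0, 5.0), (6.0, 3.0)]
```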

check_if_mirror_extended(lrs)

Check whether the input is a mirror-extended sequence.

If mirror-extended, the i-th (i=0, …, t-1) frame is equal to the (t-1-i)-th frame.

Parameters

lrs (Tensor) – Input LR images with shape (n, t, c, h, w).

compute_flow(lrs)

Compute optical flow using SPyNet for feature warping.

Note that if the input is a mirror-extended sequence, ‘flows_forward’ is not needed, since it is equal to ‘flows_backward.flip(1)’.

Parameters

lrs (Tensor) – Input LR images with shape (n, t, c, h, w).

Returns

Optical flow. ‘flows_forward’ corresponds to the flows used for forward-time propagation (current to previous). ‘flows_backward’ corresponds to the flows used for backward-time propagation (current to next).

Return type

tuple(Tensor)
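The flow computation above pairs each frame with its temporal neighbour. The sketch is built on several assumptions. Frames 0..t-2 are paired with frames 1..t-1 and passed to the flow estimator as one flattened batch. zero_flow is a hypothetical placeholder for SPyNet that returns all-zero (N, 2, h, w) flow fields. The resulting flows are reshaped to (n, t-1, 2, h, w). NumPy stands in for torch.

```python
from typing import Callable, Optional, Tuple

import numpy as np

# Hypothetical flow estimator type: two (N, c, h, w) batches -> (N, 2, h, w).
FlowFn = Callable[[np.ndarray, np.ndarray], np.ndarray]


def zero_flow(ref: np.ndarray, supp: np.ndarray) -> np.ndarray:
    """Placeholder for SPyNet: returns a zero flow field per frame pair."""
    n, _, h, w = ref.shape
    return np.zeros((n, 2, h, w))


def compute_flow_sketch(lrs: np.ndarray, flow_fn: FlowFn,
                        mirror_extended: bool = False
                        ) -> Tuple[Optional[np.ndarray], np.ndarray]:
    n, t, c, h, w = lrs.shape
    # Flatten neighbouring pairs into batches: frames 0..t-2 and 1..t-1.
    lrs_1 = lrs[:, :-1].reshape(-1, c, h, w)
    lrs_2 = lrs[:, 1:].reshape(-1, c, h, w)
    # Flows for backward-time propagation (current to next).
    flows_backward = flow_fn(lrs_1, lrs_2).reshape(n, t - 1, 2, h, w)
    if mirror_extended:
        # Skipped: equal to flows_backward flipped along the time axis.
        flows_forward = None
    else:
        # Flows for forward-time propagation (current to previous).
        flows_forward = flow_fn(lrs_2, lrs_1).reshape(n, t - 1, 2, h, w)
    return flows_forward, flows_backward


fwd, bwd = compute_flow_sketch(np.zeros((2, 5, 3, 16, 16)), zero_flow)
print(fwd.shape, bwd.shape)  # (2, 4, 2, 16, 16) (2, 4, 2, 16, 16)
```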

forward(lrs)

Forward function for BasicVSR.

Parameters

lrs (Tensor) – Input LR sequence with shape (n, t, c, h, w).

Returns

Output HR sequence with shape (n, t, c, 4h, 4w).

Return type

Tensor
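The forward contract is purely about shapes: each LR frame of size h × w becomes an HR frame of size 4h × 4w, and the batch, time, and channel dimensions are preserved. The sketch below reproduces only that contract. It is a hypothetical stand-in that uses nearest-neighbour upsampling, not the network's learned reconstruction.

```python
import numpy as np


def forward_shape_sketch(lrs: np.ndarray, scale: int = 4) -> np.ndarray:
    """Hypothetical stand-in for BasicVSRNet.forward reproducing only the
    documented shape contract: (n, t, c, h, w) -> (n, t, c, 4h, 4w)."""
    if lrs.ndim != 5:
        raise ValueError(f'expected (n, t, c, h, w), got shape {lrs.shape}')
    # Nearest-neighbour upsampling along h (axis 3) and w (axis 4); the real
    # network predicts HR frames from propagated features instead.
    return lrs.repeat(scale, axis=3).repeat(scale, axis=4)


out = forward_shape_sketch(np.zeros((1, 7, 3, 64, 64)))
print(out.shape)  # (1, 7, 3, 256, 256)
```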
