
mmedit.models.editors.tof.tof_vfi_net

Module Contents

Classes

TOFlowVFINet

PyTorch implementation of TOFlow for video frame interpolation.

BasicModule

Basic module of SPyNet.

SPyNet

SPyNet architecture.

ToFResBlock

ResNet architecture.

class mmedit.models.editors.tof.tof_vfi_net.TOFlowVFINet(rgb_mean=[0.485, 0.456, 0.406], rgb_std=[0.229, 0.224, 0.225], flow_cfg=dict(norm_cfg=None, pretrained=None), init_cfg=None)[source]

Bases: mmengine.model.BaseModule

PyTorch implementation of TOFlow for video frame interpolation.

Paper: Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018

Code reference:

  1. https://github.com/anchen1011/toflow

  2. https://github.com/Coldog2333/pytoflow

Parameters
  • rgb_mean (list[float]) – Image mean in RGB order. Default: [0.485, 0.456, 0.406]

  • rgb_std (list[float]) – Image std in RGB order. Default: [0.229, 0.224, 0.225]

  • flow_cfg (dict) – Config of SPyNet. Default: dict(norm_cfg=None, pretrained=None)

  • init_cfg (dict, optional) – Initialization config dict. Default: None.
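The SPyNet notes on this page state that normalization and denormalization are done in TOFlow, using the rgb_mean and rgb_std above. The sketch below shows that per-channel arithmetic on a single RGB pixel in plain Python. It assumes pixel values lie in [0, 1] (the usual range for these ImageNet statistics); the helper names are hypothetical, and the real module applies the same arithmetic to whole tensors rather than lists.

```python
# Minimal sketch: per-channel normalization of one RGB pixel.
# Assumption: pixel values lie in [0, 1]; TOFlowVFINet itself works on tensors.
RGB_MEAN = [0.485, 0.456, 0.406]  # documented default rgb_mean
RGB_STD = [0.229, 0.224, 0.225]   # documented default rgb_std


def normalize(pixel, mean=RGB_MEAN, std=RGB_STD):
    """x_norm = (x - mean) / std, applied channel by channel."""
    return [(v - m) / s for v, m, s in zip(pixel, mean, std)]


def denormalize(pixel, mean=RGB_MEAN, std=RGB_STD):
    """Inverse of normalize(): x = x_norm * std + mean."""
    return [v * s + m for v, m, s in zip(pixel, mean, std)]


pixel = [0.5, 0.5, 0.5]
restored = denormalize(normalize(pixel))
print([round(v, 6) for v in restored])  # [0.5, 0.5, 0.5]
```

Because denormalization inverts normalization exactly, the network can operate in normalized space while the returned frame stays in the input's value range.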

forward(imgs)[source]
Parameters

imgs (Tensor) – Input frames with shape (b, 2, 3, h, w).

Returns

Interpolated frame with shape of (b, 3, h, w).

Return type

Tensor
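The forward() contract above maps two RGB frames per sample to one interpolated RGB frame at the same resolution. The hypothetical helper below encodes only that documented shape contract, which can serve as a quick sanity check before calling the model. It does not run the network.

```python
def tof_vfi_output_shape(input_shape):
    """Output shape implied by the documented TOFlowVFINet.forward() contract.

    Input:  (b, 2, 3, h, w), two RGB frames per sample.
    Output: (b, 3, h, w), one interpolated RGB frame per sample.
    """
    if len(input_shape) != 5:
        raise ValueError(f"expected a 5-D shape, got {input_shape}")
    b, n_frames, c, h, w = input_shape
    if n_frames != 2 or c != 3:
        raise ValueError(f"expected (b, 2, 3, h, w), got {input_shape}")
    # The frame axis disappears; batch, channel and spatial sizes are kept.
    return (b, c, h, w)


print(tof_vfi_output_shape((4, 2, 3, 256, 448)))  # (4, 3, 256, 448)
```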

class mmedit.models.editors.tof.tof_vfi_net.BasicModule(norm_cfg)[source]

Bases: torch.nn.Module

Basic module of SPyNet.

Note that, unlike the common SPyNet architecture, the basic module here may contain batch normalization.

Parameters

norm_cfg (dict | None) – Config of normalization.

forward(tensor_input)[source]
Parameters

tensor_input (Tensor) – Input tensor with shape (b, 8, h, w). The 8 channels are: [reference image (3), neighbor image (3), initial flow (2)].

Returns

Estimated flow with shape (b, 2, h, w)

Return type

Tensor
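The 8-channel input described above is a channel-wise concatenation in a fixed order. The sketch below reproduces that ordering in plain Python, treating one sample as a list of channels, where each channel is an h x w nested list. The function names are hypothetical. With real tensors the equivalent operation would be torch.cat along the channel dimension (dim=1 for (b, c, h, w) inputs).

```python
def make_basic_module_input(ref, neighbor, flow):
    """Stack channels in the documented order: ref (3), neighbor (3), flow (2)."""
    if (len(ref), len(neighbor), len(flow)) != (3, 3, 2):
        raise ValueError("expected 3 reference, 3 neighbor and 2 flow channels")
    return ref + neighbor + flow  # list concatenation = channel concatenation


def zeros(h, w):
    """An h x w channel filled with zeros."""
    return [[0.0] * w for _ in range(h)]


h, w = 4, 6
ref = [zeros(h, w) for _ in range(3)]
neighbor = [zeros(h, w) for _ in range(3)]
flow = [zeros(h, w) for _ in range(2)]
x = make_basic_module_input(ref, neighbor, flow)
print(len(x), len(x[0]), len(x[0][0]))  # 8 4 6
```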

class mmedit.models.editors.tof.tof_vfi_net.SPyNet(norm_cfg, pretrained=None)[source]

Bases: torch.nn.Module

SPyNet architecture.

Note that this implementation is specifically for TOFlow. It differs from the common SPyNet in the following aspects:

  1. The basic modules in the TOFlow paper contain BatchNorm.

  2. Normalization and denormalization are not done here, as they are done in TOFlow.

Paper:

Optical Flow Estimation using a Spatial Pyramid Network

Code reference:

https://github.com/Coldog2333/pytoflow

Parameters
  • norm_cfg (dict | None) – Config of normalization.

  • pretrained (str) – Path to a pre-trained SPyNet checkpoint. Default: None.
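SPyNet (the spatial pyramid network cited above) estimates flow coarse to fine. At each pyramid level, the flow from the coarser level is upsampled by 2 and a basic module predicts a residual correction that is added to it. When the resolution doubles, the displacement values must also be doubled. The sketch below shows one such refinement step in plain Python. It uses nearest-neighbour upsampling for brevity, whereas implementations typically use bilinear interpolation. The residual is passed in directly here as a hypothetical stand-in for the BasicModule output.

```python
def upsample_flow_2x(flow):
    """Nearest-neighbour 2x upsampling of a (2, h, w) flow field.

    Displacements are multiplied by 2 because one pixel at the coarse level
    spans two pixels at the next, finer level.
    """
    out = []
    for channel in flow:
        rows = []
        for row in channel:
            wide = [2.0 * v for v in row for _ in range(2)]  # repeat columns
            rows.append(wide)
            rows.append(list(wide))  # repeat rows
        out.append(rows)
    return out


def refine(coarse_flow, residual):
    """One pyramid step: upsampled coarse flow plus a residual correction."""
    up = upsample_flow_2x(coarse_flow)
    return [[[u + r for u, r in zip(ur, rr)] for ur, rr in zip(uc, rc)]
            for uc, rc in zip(up, residual)]


coarse = [[[1.0]], [[-0.5]]]  # a 1 x 1 flow field: (dx, dy) = (1.0, -0.5)
residual = [[[0.0, 0.0], [0.0, 0.0]] for _ in range(2)]  # 2 x 2 correction
fine = refine(coarse, residual)
print(fine[0])  # [[2.0, 2.0], [2.0, 2.0]]
```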

forward(ref, supp)[source]
Parameters
  • ref (Tensor) – Reference image with shape of (b, 3, h, w).

  • supp (Tensor) – Supporting image to be warped, with shape (b, 3, h, w).

Returns

Estimated optical flow with shape (b, 2, h, w).

Return type

Tensor
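The parameter description suggests that the returned flow is used to warp supp toward ref. The sketch below shows backward warping in plain Python: each output pixel (y, x) samples supp at (y + dy, x + dx). Three assumptions apply here. First, flow channel 0 holds the horizontal displacement and channel 1 the vertical one. Second, sampling uses the nearest pixel, with coordinates clamped at the border. Third, the helper name is hypothetical. Real implementations usually sample bilinearly instead, for example with torch.nn.functional.grid_sample.

```python
def warp_nearest(image, flow):
    """Backward-warp a (c, h, w) image with a (2, h, w) flow field.

    Assumes flow[0] is the horizontal (x) and flow[1] the vertical (y)
    displacement. Output pixel (y, x) takes image[:, y + dy, x + dx],
    rounded to the nearest pixel and clamped to the image border.
    """
    c = len(image)
    h, w = len(image[0]), len(image[0][0])
    out = [[[0.0] * w for _ in range(h)] for _ in range(c)]
    for y in range(h):
        for x in range(w):
            sx = min(max(round(x + flow[0][y][x]), 0), w - 1)
            sy = min(max(round(y + flow[1][y][x]), 0), h - 1)
            for ch in range(c):
                out[ch][y][x] = image[ch][sy][sx]
    return out


supp = [[[0.0, 1.0, 2.0]]]                     # one channel, 1 x 3 image
flow = [[[1.0, 1.0, 1.0]], [[0.0, 0.0, 0.0]]]  # sample one pixel to the right
print(warp_nearest(supp, flow))  # [[[1.0, 2.0, 2.0]]]
```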

class mmedit.models.editors.tof.tof_vfi_net.ToFResBlock[source]

Bases: torch.nn.Module

ResNet architecture.

Three-layer ResNet/ResBlock.

forward(frames)[source]
Parameters

frames (Tensor) – Tensor with shape of (b, 2, 3, h, w).

Returns

Interpolated frame with shape of (b, 3, h, w).

Return type

Tensor
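The docstring describes a residual block that maps two frames (b, 2, 3, h, w) to one frame (b, 3, h, w), but it does not say how the frames are combined. The sketch below is one hypothetical reading, shown per sample in plain Python. The two frames are flattened into 6 channels for the learned layers, and the learned correction is added to the per-pixel average of the two frames. Both the flattening and the "average plus residual" choice are assumptions, not taken from the source. The zero_correction function is a placeholder for the three convolution layers.

```python
def flatten_frames(frames):
    """(2, 3, h, w) -> (6, h, w): frame 0's RGB channels, then frame 1's."""
    return [channel for frame in frames for channel in frame]


def average_frames(frames):
    """Per-pixel mean of the two frames, shape (3, h, w)."""
    f0, f1 = frames
    return [[[(a + b) / 2 for a, b in zip(r0, r1)] for r0, r1 in zip(c0, c1)]
            for c0, c1 in zip(f0, f1)]


def residual_output(frames, predict_residual):
    """Hypothetical residual head: average of the inputs plus a correction."""
    base = average_frames(frames)
    corr = predict_residual(flatten_frames(frames))
    return [[[b + c for b, c in zip(rb, rc)] for rb, rc in zip(cb, cc)]
            for cb, cc in zip(base, corr)]


def zero_correction(stacked):
    """Placeholder for the learned layers: predicts a zero (3, h, w) residual."""
    h, w = len(stacked[0]), len(stacked[0][0])
    return [[[0.0] * w for _ in range(h)] for _ in range(3)]


frames = [
    [[[0.0]], [[0.25]], [[0.5]]],  # frame 0: three 1 x 1 channels
    [[[1.0]], [[0.75]], [[0.5]]],  # frame 1
]
print(residual_output(frames, zero_correction))  # [[[0.5]], [[0.5]], [[0.5]]]
```

With a zero correction the output reduces to plain frame averaging. A residual design lets the learned layers focus on the difference from that baseline.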
