
mmedit.models.editors.dim.dim

Module Contents

Classes

DIM

Deep Image Matting model.

class mmedit.models.editors.dim.dim.DIM(data_preprocessor, backbone, refiner=None, train_cfg=None, test_cfg=None, loss_alpha=None, loss_comp=None, loss_refine=None, init_cfg: Optional[dict] = None)

Bases: mmedit.models.base_models.BaseMattor

Deep Image Matting model.

https://arxiv.org/abs/1703.03872

Note

The pair (self.train_cfg.train_backbone, self.train_cfg.train_refiner) selects the training stage:

  • (True, False) corresponds to the encoder-decoder stage in the paper.

  • (False, True) corresponds to the refinement stage in the paper.

  • (True, True) corresponds to the fine-tune stage in the paper.
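
The mapping above can be written down as plain config dictionaries. This is a minimal sketch: the stage labels and the `stage_name` helper are hypothetical names introduced for illustration, and only the two flag names come from the note.

```python
# Flag pairs for each training stage described in the note above.
# The dictionary keys (stage labels) are illustrative, not library identifiers.
STAGES = {
    'encoder_decoder': dict(train_backbone=True, train_refiner=False),
    'refinement': dict(train_backbone=False, train_refiner=True),
    'fine_tune': dict(train_backbone=True, train_refiner=True),
}


def stage_name(train_backbone: bool, train_refiner: bool) -> str:
    """Return the stage label for a flag pair (hypothetical helper)."""
    wanted = dict(train_backbone=train_backbone, train_refiner=train_refiner)
    for name, flags in STAGES.items():
        if flags == wanted:
            return name
    raise ValueError(f'no training stage uses flags {wanted}')


# A train_cfg for the refinement stage would then be passed to the model config.
train_cfg = STAGES['refinement']
```

Note that (False, False) is not listed in the note, so the sketch rejects it rather than guessing a meaning.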

Parameters
  • data_preprocessor (dict, optional) – Config of data pre-processor.

  • backbone (dict) – Config of backbone.

  • refiner (dict, optional) – Config of refiner. Default: None.

  • loss_alpha (dict) – Config of the alpha prediction loss. Default: None.

  • loss_comp (dict) – Config of the composition loss. Default: None.

  • loss_refine (dict) – Config of the loss of the refiner. Default: None.

  • train_cfg (dict) – Config of training. In train_cfg, train_backbone should be specified. If the model has a refiner, train_refiner should be specified.

  • test_cfg (dict) – Config of testing. In test_cfg, if the model has a refiner, train_refiner should be specified.

  • init_cfg (dict, optional) – Weight initialization config for BaseModule. Default: None.
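
To show how these parameters fit together, here is a sketch of a model config for the encoder-decoder stage. The keyword names come from the signature above; every `type` value is an assumption about what the installed registry provides and should be checked against your version. `test_cfg` and `init_cfg` are left at their defaults.

```python
# Sketch of a DIM model config. All `type` strings are assumed registry names,
# not confirmed by this page; replace them with the ones your install exposes.
model = dict(
    type='DIM',
    data_preprocessor=dict(type='MattorPreprocessor'),  # assumed name
    backbone=dict(
        type='SimpleEncoderDecoder',  # assumed name
        # 4 input channels: 3 for the merged image plus 1 for the trimap.
        encoder=dict(type='VGG16', in_channels=4),  # assumed name
        decoder=dict(type='PlainDecoder'),  # assumed name
    ),
    refiner=dict(type='PlainRefiner'),  # assumed name
    loss_alpha=dict(type='CharbonnierLoss'),  # assumed name
    loss_comp=dict(type='CharbonnierCompLoss'),  # assumed name
    loss_refine=None,  # refiner loss is not needed while the refiner is frozen
    # Encoder-decoder stage: train the backbone, not the refiner.
    train_cfg=dict(train_backbone=True, train_refiner=False),
)
```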

property with_refiner

Whether the matting model has a refiner.

init_weights()

Initialize the model network weights.

train(mode=True)

Mode switcher.

Parameters

mode (bool) – Whether to set training mode (True) or evaluation mode (False). Default: True.

freeze_backbone()

Freeze the backbone and only train the refiner.
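
Freezing a submodule in PyTorch usually means switching it to evaluation mode and disabling gradients on its parameters. The sketch below shows that general pattern on stand-in modules; `freeze_module` and the toy `backbone` and `refiner` are hypothetical, and this is not necessarily how `freeze_backbone()` is implemented.

```python
import torch.nn as nn


def freeze_module(module: nn.Module) -> None:
    """Put a submodule in eval mode and stop its parameters from being updated."""
    module.eval()  # fixes BatchNorm statistics and disables dropout
    for param in module.parameters():
        param.requires_grad = False


# Toy stand-ins for the real backbone and refiner.
backbone = nn.Sequential(nn.Conv2d(4, 8, 3, padding=1), nn.BatchNorm2d(8))
refiner = nn.Conv2d(4, 1, 3, padding=1)

freeze_module(backbone)

# Only the refiner's parameters remain trainable and would go to the optimizer.
trainable = [p for p in refiner.parameters() if p.requires_grad]
```

Calling `train(True)` on a parent module puts every child back into training mode, so a frozen submodule needs `eval()` reapplied afterwards.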

_forward(x: torch.Tensor, *, refine: bool = True) -> Tuple[torch.Tensor, torch.Tensor]

Raw forward function.

Parameters
  • x (torch.Tensor) – Concatenation of merged image and trimap, with shape (N, 4, H, W).

  • refine (bool) – Whether to forward through the refiner. Default: True.

Returns

  • pred_alpha, with shape (N, 1, H, W)

  • pred_refine, with shape (N, 4, H, W)

Return type

Tuple[torch.Tensor, torch.Tensor]
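
The (N, 4, H, W) input is a channel-wise concatenation of a 3-channel merged image and a 1-channel trimap. A minimal sketch of building such a tensor, using random data and arbitrary sizes chosen for illustration:

```python
import torch

batch, height, width = 2, 64, 64  # arbitrary illustrative sizes

merged = torch.rand(batch, 3, height, width)  # stand-in for the merged RGB image
trimap = torch.rand(batch, 1, height, width)  # stand-in for the trimap

# Concatenate along the channel dimension (dim=1) to get 3 + 1 = 4 channels.
x = torch.cat((merged, trimap), dim=1)
```

The resulting `x` has the layout that `_forward` expects.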

_forward_test(inputs)

Forward to get alpha prediction.

_forward_train(inputs, data_samples)

Defines the computation performed at every training call.

Parameters
  • inputs (torch.Tensor) – Concatenation of normalized image and trimap, with shape (N, 4, H, W).

  • data_samples (list[EditDataSample]) – Data samples containing:

    • gt_alpha (Tensor): Ground-truth alpha, with shape (N, 1, H, W), normalized to the range 0 to 1.

    • gt_fg (Tensor): Ground-truth foreground, with shape (N, C, H, W), normalized to the range 0 to 1.

    • gt_bg (Tensor): Ground-truth background, with shape (N, C, H, W), normalized to the range 0 to 1.

Returns

Contains the loss items and batch information.

Return type

dict
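
The ground-truth foreground and background are needed because a composition loss compares images, not alpha mattes. The sketch below shows the standard compositing equation, merged = alpha * fg + (1 - alpha) * bg, on random tensors. The `composite` helper is hypothetical, and how the configured loss_comp consumes the result is not specified on this page.

```python
import torch


def composite(alpha: torch.Tensor, fg: torch.Tensor, bg: torch.Tensor) -> torch.Tensor:
    """Blend foreground over background; alpha (N, 1, H, W) broadcasts over channels."""
    return alpha * fg + (1.0 - alpha) * bg


# Random stand-ins with the shapes listed above, all values in [0, 1).
gt_alpha = torch.rand(2, 1, 8, 8)
gt_fg = torch.rand(2, 3, 8, 8)
gt_bg = torch.rand(2, 3, 8, 8)

merged = composite(gt_alpha, gt_fg, gt_bg)
```

A predicted alpha can be composited the same way and compared against the merged image built from the ground truth.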
