mmedit.models.losses.composition_loss
Module Contents
Classes
L1CompositionLoss: L1 composition loss.
MSECompositionLoss: MSE (L2) composition loss.
CharbonnierCompLoss: Charbonnier composition loss.
 class mmedit.models.losses.composition_loss.L1CompositionLoss(loss_weight: float = 1.0, reduction: str = 'mean', sample_wise: bool = False)
Bases: torch.nn.Module
L1 composition loss.
 Parameters
loss_weight (float) – Loss weight for L1 loss. Default: 1.0.
reduction (str) – Specifies the reduction to apply to the output. Supported choices are ‘none’ | ‘mean’ | ‘sum’. Default: ‘mean’.
sample_wise (bool) – Whether to calculate the loss sample-wise. This argument only takes effect when reduction is ‘mean’ and weight (argument of forward()) is not None. The loss is first reduced with ‘mean’ per sample, and then averaged over all samples. Default: False.
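The difference between the two ‘mean’ modes can be illustrated with a small plain-PyTorch sketch. The helper name masked_mean and its eps guard against division by zero are assumptions made for illustration; this is not the library's own reduction code.

```python
import torch


def masked_mean(loss, weight, sample_wise=False, eps=1e-12):
    """Hypothetical helper: 'mean' reduction of an (N, C, H, W) loss under a mask."""
    loss = loss * weight
    if sample_wise:
        # Normalize each sample by its own weight sum, then average over samples.
        per_sample = loss.sum(dim=(1, 2, 3)) / (weight.sum(dim=(1, 2, 3)) + eps)
        return per_sample.mean()
    # Normalize the batch total by the total weight of the whole batch.
    return loss.sum() / (weight.sum() + eps)


# Two samples of shape (1, 1, 4): sample 0 is fully weighted,
# sample 1 has a single weighted pixel.
loss = torch.tensor([[[[1.0, 1.0, 1.0, 1.0]]], [[[3.0, 3.0, 3.0, 3.0]]]])
weight = torch.tensor([[[[1.0, 1.0, 1.0, 1.0]]], [[[1.0, 0.0, 0.0, 0.0]]]])

batch_mean = masked_mean(loss, weight)                     # (4 * 1 + 3) / 5 = 1.4
sample_mean = masked_mean(loss, weight, sample_wise=True)  # (1 + 3) / 2 = 2.0
```

With sample_wise=True every sample contributes equally, however many pixels its mask covers; otherwise samples with larger masks dominate the average.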
 forward(pred_alpha: torch.Tensor, fg: torch.Tensor, bg: torch.Tensor, ori_merged: torch.Tensor, weight: Optional[torch.Tensor] = None, **kwargs) → torch.Tensor
 Parameters
pred_alpha (Tensor) – of shape (N, 1, H, W). Predicted alpha matte.
fg (Tensor) – of shape (N, 3, H, W). Tensor of foreground object.
bg (Tensor) – of shape (N, 3, H, W). Tensor of background object.
ori_merged (Tensor) – of shape (N, 3, H, W). Tensor of the original merged image before normalization by the ImageNet mean and std.
weight (Tensor, optional) – of shape (N, 1, H, W). It is an indicator matrix: weight[trimap == 128] = 1. Default: None.
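As a rough illustration of what a composition loss measures, the sketch below re-composites an image from the predicted alpha matte and compares it with ori_merged using an L1 distance. It is written in plain PyTorch under the assumption that the composite is pred_alpha * fg + (1 - pred_alpha) * bg; the function name l1_composition_sketch is hypothetical and this is not the library's implementation.

```python
import torch


def l1_composition_sketch(pred_alpha, fg, bg, ori_merged, weight=None, eps=1e-12):
    """Hypothetical sketch of an L1 composition loss with 'mean' reduction."""
    # Re-composite the image from the predicted alpha matte (assumed formula).
    pred_merged = pred_alpha * fg + (1.0 - pred_alpha) * bg
    diff = (pred_merged - ori_merged).abs()
    if weight is None:
        return diff.mean()
    # Broadcast the (N, 1, H, W) mask over the three colour channels.
    weight = weight.expand_as(diff)
    return (diff * weight).sum() / (weight.sum() + eps)


torch.manual_seed(0)
n, h, w = 2, 4, 4
fg = torch.rand(n, 3, h, w)
bg = torch.rand(n, 3, h, w)
true_alpha = torch.rand(n, 1, h, w)
ori_merged = true_alpha * fg + (1.0 - true_alpha) * bg

# The ground-truth alpha reproduces the merged image, so the loss vanishes.
loss_perfect = l1_composition_sketch(true_alpha, fg, bg, ori_merged)
# An all-zero alpha composites only the background, so the loss is positive.
loss_wrong = l1_composition_sketch(torch.zeros_like(true_alpha), fg, bg, ori_merged)
```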
 class mmedit.models.losses.composition_loss.MSECompositionLoss(loss_weight: float = 1.0, reduction: str = 'mean', sample_wise: bool = False)
Bases: torch.nn.Module
MSE (L2) composition loss.
 Parameters
loss_weight (float) – Loss weight for MSE loss. Default: 1.0.
reduction (str) – Specifies the reduction to apply to the output. Supported choices are ‘none’ | ‘mean’ | ‘sum’. Default: ‘mean’.
sample_wise (bool) – Whether to calculate the loss sample-wise. This argument only takes effect when reduction is ‘mean’ and weight (argument of forward()) is not None. The loss is first reduced with ‘mean’ per sample, and then averaged over all samples. Default: False.
 forward(pred_alpha: torch.Tensor, fg: torch.Tensor, bg: torch.Tensor, ori_merged: torch.Tensor, weight: Optional[torch.Tensor] = None, **kwargs) → torch.Tensor
 Parameters
pred_alpha (Tensor) – of shape (N, 1, H, W). Predicted alpha matte.
fg (Tensor) – of shape (N, 3, H, W). Tensor of foreground object.
bg (Tensor) – of shape (N, 3, H, W). Tensor of background object.
ori_merged (Tensor) – of shape (N, 3, H, W). Tensor of the original merged image before normalization by the ImageNet mean and std.
weight (Tensor, optional) – of shape (N, 1, H, W). It is an indicator matrix: weight[trimap == 128] = 1. Default: None.
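The weight argument restricts the loss to the unknown region of the trimap. The sketch below builds such a mask from a tiny hand-made trimap and applies it to a squared-error composition loss. The trimap encoding (0 background, 128 unknown, 255 foreground), the function name mse_composition_sketch, and the toy values are assumptions for illustration only.

```python
import torch


def mse_composition_sketch(pred_alpha, fg, bg, ori_merged, weight=None, eps=1e-12):
    """Hypothetical sketch of an MSE composition loss with 'mean' reduction."""
    pred_merged = pred_alpha * fg + (1.0 - pred_alpha) * bg
    sq_err = (pred_merged - ori_merged) ** 2
    if weight is None:
        return sq_err.mean()
    # Broadcast the (N, 1, H, W) mask over the three colour channels.
    weight = weight.expand_as(sq_err)
    return (sq_err * weight).sum() / (weight.sum() + eps)


# Assumed trimap encoding: 0 = background, 128 = unknown, 255 = foreground.
trimap = torch.tensor([[[[0, 128], [128, 255]]]])  # shape (1, 1, 2, 2)
weight = (trimap == 128).float()                   # 1 inside the unknown region

fg = torch.ones(1, 3, 2, 2)                # white foreground
bg = torch.zeros(1, 3, 2, 2)               # black background
ori_merged = torch.full((1, 3, 2, 2), 0.5)
pred_alpha = torch.tensor([[[[1.0, 0.5], [0.0, 0.0]]]])

# Squared errors per pixel are 0.25, 0, 0.25, 0.25 (same in every channel).
loss_all = mse_composition_sketch(pred_alpha, fg, bg, ori_merged)              # 0.1875
loss_unknown = mse_composition_sketch(pred_alpha, fg, bg, ori_merged, weight)  # 0.125
```

Only the two unknown pixels contribute to loss_unknown, so errors in the known foreground and background regions are ignored.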
 class mmedit.models.losses.composition_loss.CharbonnierCompLoss(loss_weight: float = 1.0, reduction: str = 'mean', sample_wise: bool = False, eps: float = 1e-12)
Bases: torch.nn.Module
Charbonnier composition loss.
 Parameters
loss_weight (float) – Loss weight for Charbonnier loss. Default: 1.0.
reduction (str) – Specifies the reduction to apply to the output. Supported choices are ‘none’ | ‘mean’ | ‘sum’. Default: ‘mean’.
sample_wise (bool) – Whether to calculate the loss sample-wise. This argument only takes effect when reduction is ‘mean’ and weight (argument of forward()) is not None. The loss is first reduced with ‘mean’ per sample, and then averaged over all samples. Default: False.
eps (float) – A value used to control the curvature near zero. Default: 1e-12.
 forward(pred_alpha: torch.Tensor, fg: torch.Tensor, bg: torch.Tensor, ori_merged: torch.Tensor, weight: Optional[torch.Tensor] = None, **kwargs) → torch.Tensor
 Parameters
pred_alpha (Tensor) – of shape (N, 1, H, W). Predicted alpha matte.
fg (Tensor) – of shape (N, 3, H, W). Tensor of foreground object.
bg (Tensor) – of shape (N, 3, H, W). Tensor of background object.
ori_merged (Tensor) – of shape (N, 3, H, W). Tensor of the original merged image before normalization by the ImageNet mean and std.
weight (Tensor, optional) – of shape (N, 1, H, W). It is an indicator matrix: weight[trimap == 128] = 1. Default: None.
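The role of eps can be seen in a small sketch that assumes the usual Charbonnier penalty sqrt(x^2 + eps) applied to the composition error. The function name charbonnier_composition_sketch and the small eps value of 1e-12 used here are assumptions for illustration, not a statement of the library's internals.

```python
import torch


def charbonnier_composition_sketch(pred_alpha, fg, bg, ori_merged, eps=1e-12):
    """Hypothetical sketch of a Charbonnier composition loss with 'mean' reduction."""
    pred_merged = pred_alpha * fg + (1.0 - pred_alpha) * bg
    # Smooth approximation of |x|: close to |x| for large errors, rounded near zero.
    penalty = torch.sqrt((pred_merged - ori_merged) ** 2 + eps)
    return penalty.mean()


fg = torch.ones(1, 3, 2, 2)
bg = torch.zeros(1, 3, 2, 2)
ori_merged = torch.full((1, 3, 2, 2), 0.5)

# Exact alpha: the composition error is zero, so the loss equals sqrt(eps) = 1e-6.
exact_alpha = torch.full((1, 1, 2, 2), 0.5, requires_grad=True)
loss_zero = charbonnier_composition_sketch(exact_alpha, fg, bg, ori_merged)
loss_zero.backward()
grad_is_finite = bool(torch.isfinite(exact_alpha.grad).all())

# Alpha of 1: every pixel is off by 0.5, so the loss is about 0.5, like an L1 loss.
loss_half = charbonnier_composition_sketch(torch.ones(1, 1, 2, 2), fg, bg, ori_merged)
```

A larger eps widens the smooth region around zero, while a very small eps makes the penalty behave almost exactly like the absolute error.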