Source code for mmedit.models.losses.loss_wrapper

# Copyright (c) OpenMMLab. All rights reserved.
import functools
from typing import Optional

import torch
import torch.nn.functional as F

def reduce_loss(loss: torch.Tensor, reduction: str) -> torch.Tensor:
    """Collapse an element-wise loss tensor according to ``reduction``.

    Args:
        loss (Tensor): Element-wise loss tensor.
        reduction (str): Reduction mode. Options are "none", "mean" and
            "sum".

    Returns:
        Tensor: ``loss`` itself for "none", its mean for "mean", or its
        sum for "sum".

    Raises:
        ValueError: If the reduction code is not one of the three handled
            here. NOTE(review): ``F._Reduction.get_enum`` presumably
            rejects unknown strings on its own before this check is
            reached -- confirm against the installed PyTorch version.
    """
    # Integer codes produced by PyTorch's helper:
    # none -> 0, elementwise mean -> 1, sum -> 2
    code = F._Reduction.get_enum(reduction)

    if code not in (0, 1, 2):
        raise ValueError(f'reduction type {reduction} not supported')

    if code == 1:
        return loss.mean()
    if code == 2:
        return loss.sum()
    # code == 0: hand back the untouched element-wise loss
    return loss
def mask_reduce_loss(loss: torch.Tensor,
                     weight: Optional[torch.Tensor] = None,
                     reduction: str = 'mean',
                     sample_wise: bool = False) -> torch.Tensor:
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights with the same number of
            dimensions as ``loss``. When the tensors have a second
            (channel) axis, its size must be 1 or match ``loss``.
            Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            "none", "mean" and "sum". Default: 'mean'.
        sample_wise (bool): Whether calculate the loss sample-wise. This
            argument only takes effect when `reduction` is 'mean' and
            `weight` is not None. It first averages the loss over the
            weighted region of each sample (all axes except the first),
            and then averages over the samples. Default: False.

    Returns:
        Tensor: Processed loss values. When ``weight`` is given and
        ``reduction`` is neither 'sum' nor 'mean', the weighted
        element-wise loss is returned without reduction.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        assert weight.dim() == loss.dim()
        # The channel constraint only makes sense when a channel axis
        # exists; 0-D / 1-D losses have no dim 1 to inspect.
        if loss.dim() > 1:
            assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # if weight is not specified or reduction is sum, just reduce the loss
    if weight is None or reduction == 'sum':
        loss = reduce_loss(loss, reduction)
    # if reduction is mean, then compute mean over masked region
    elif reduction == 'mean':
        # expand a single-channel weight (e.g. N1HW) to the loss shape
        # (e.g. NCHW) so that its sum counts every weighted element
        if weight.dim() > 1 and weight.size(1) == 1:
            weight = weight.expand_as(loss)
        # small value to prevent division by zero
        eps = 1e-12

        # perform sample-wise mean (a 0-D loss has no batch axis, so it
        # falls through to the pixel-wise formula)
        if sample_wise and weight.dim() > 0:
            # Sum over every non-batch axis instead of hard-coding
            # [1, 2, 3], so inputs of any rank are supported.
            non_batch_dims = list(range(1, weight.dim()))
            if non_batch_dims:
                # e.g. NCHW -> N111
                weight = weight.sum(dim=non_batch_dims, keepdim=True)
            loss = (loss / (weight + eps)).sum() / weight.size(0)
        # perform pixel-wise mean
        else:
            loss = loss.sum() / (weight.sum() + eps)

    return loss
def masked_loss(loss_func):
    """Create a masked version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    sample_wise=False, **kwargs)`.

    :Example:

    >>> import torch
    >>> @masked_loss
    ... def l1_loss(pred, target):
    ...     return (pred - target).abs()

    >>> pred = torch.Tensor([[0, 2, 3]])
    >>> target = torch.Tensor([[1, 1, 1]])
    >>> weight = torch.Tensor([[1, 0, 1]])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([[1., 1., 2.]])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def _masked(pred: torch.Tensor,
                target: torch.Tensor,
                weight: Optional[torch.Tensor] = None,
                reduction: str = 'mean',
                sample_wise: bool = False,
                **kwargs) -> torch.Tensor:
        # The wrapped function yields the raw element-wise loss; weighting
        # and reduction are delegated to mask_reduce_loss.
        elementwise = loss_func(pred, target, **kwargs)
        return mask_reduce_loss(elementwise, weight, reduction, sample_wise)

    return _masked
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.