Shortcuts

Source code for mmedit.models.losses.gradient_loss

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmedit.registry import MODELS
from .pixelwise_loss import l1_loss

[文档]_reduction_modes = ['none', 'mean', 'sum']
@MODELS.register_module()
class GradientLoss(nn.Module):
    """Gradient loss.

    Applies horizontal and vertical Sobel kernels to both the prediction and
    the target, then sums the L1 losses between the corresponding gradient
    maps.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 reduction: str = 'mean') -> None:
        super().__init__()
        self.loss_weight = loss_weight
        self.reduction = reduction
        # Check against the same module-level list that the error message
        # reports, so the two cannot drift apart.
        if self.reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {self.reduction}. '
                             f'Supported ones are: {_reduction_modes}')

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                weight: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor. Must have
                the same number of channels as ``target``.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.

        Returns:
            Tensor: ``l1_loss`` on the x-gradients plus ``l1_loss`` on the
            y-gradients, multiplied by ``loss_weight``. ``self.reduction`` is
            forwarded to ``l1_loss``.
        """
        num_channels = target.size(1)

        # Sobel kernels, created with target's dtype/device (equivalent to
        # the former ``torch.Tensor(...).to(target)``).
        kx = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]],
                          dtype=target.dtype,
                          device=target.device)
        ky = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]],
                          dtype=target.dtype,
                          device=target.device)

        # A single (1, 1, 3, 3) kernel only accepts single-channel input.
        # Repeating it per channel and using groups=num_channels turns this
        # into a depthwise convolution. Each channel's gradient is then
        # computed independently, and the C == 1 case is unchanged.
        kx = kx.view(1, 1, 3, 3).repeat(num_channels, 1, 1, 1)
        ky = ky.view(1, 1, 3, 3).repeat(num_channels, 1, 1, 1)

        pred_grad_x = F.conv2d(pred, kx, padding=1, groups=num_channels)
        pred_grad_y = F.conv2d(pred, ky, padding=1, groups=num_channels)
        target_grad_x = F.conv2d(target, kx, padding=1, groups=num_channels)
        target_grad_y = F.conv2d(target, ky, padding=1, groups=num_channels)

        loss = (
            l1_loss(
                pred_grad_x, target_grad_x, weight, reduction=self.reduction)
            + l1_loss(
                pred_grad_y, target_grad_y, weight, reduction=self.reduction))
        return loss * self.loss_weight
Read the Docs v: latest
Versions
master
latest
stable
zyh-doc-notfound-extend
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.