Source code for mmedit.models.losses.feature_loss

# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Optional

import torch
import torch.nn as nn
from mmengine import MMLogger
from mmengine.runner import load_checkpoint

from mmedit.models.editors.dic import LightCNN
from mmedit.registry import MODELS

class LightCNNFeature(nn.Module):
    """Feature extractor built from the ``features`` trunk of LightCNN.

    It is used to train DICGAN. All parameters are frozen
    (``requires_grad=False``) at construction time, so the extractor only
    provides features and is never optimized.
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): the meaning of the ``3`` argument (presumably the
        # number of input channels) is defined by LightCNN — confirm there.
        model = LightCNN(3)
        # Keep only the children of LightCNN's ``features`` module; the rest
        # of the instantiated network is not stored.
        self.features = nn.Sequential(*list(model.features.children()))
        # Freeze the extractor: it acts as a fixed perceptual network.
        self.features.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            x (Tensor): Input tensor, passed directly to the LightCNN
                feature trunk.

        Returns:
            Tensor: Output of the feature trunk.
        """
        return self.features(x)

    def init_weights(self,
                     pretrained: Optional[str] = None,
                     strict: bool = True) -> None:
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether strictly load the pretrained
                model. Defaults to True.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = MMLogger.get_current_instance()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is not None:
            raise TypeError(f'"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')


@MODELS.register_module()
class LightCNNFeatureLoss(nn.Module):
    """Feature loss of DICGAN, based on LightCNN.

    Args:
        pretrained (str | None): Path for pretrained weights. If it is not a
            str, a warning is emitted because the feature extractor should be
            pretrained. ``None`` skips loading and keeps LightCNN's own
            initialization. Any other non-str value makes
            ``LightCNNFeature.init_weights`` raise ``TypeError``.
        loss_weight (float): Loss weight. Default: 1.0.
        criterion (str): Criterion type. Options are 'l1' and 'mse'.
            Default: 'l1'.

    Raises:
        ValueError: If ``criterion`` is not 'l1' or 'mse'.
    """

    def __init__(self,
                 pretrained: Optional[str],
                 loss_weight: float = 1.0,
                 criterion: str = 'l1') -> None:
        super().__init__()
        self.model = LightCNNFeature()
        if not isinstance(pretrained, str):
            warnings.warn('`LightCNNFeature` model in FeatureLoss '
                          'should be pretrained')
        self.model.init_weights(pretrained)
        self.model.eval()
        self.loss_weight = loss_weight
        if criterion == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif criterion == 'mse':
            self.criterion = torch.nn.MSELoss()
        else:
            raise ValueError("'criterion' should be 'l1' or 'mse', "
                             f'but got {criterion}')

    def forward(self, pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted tensor.
            gt (Tensor): GT tensor.

        Returns:
            Tensor: Criterion between the extractor features of ``pred`` and
                ``gt``, multiplied by ``loss_weight``.
        """
        # A parent module's ``.train()`` call switches all submodules back to
        # training mode, so eval mode is re-applied on every call.
        self.model.eval()
        pred_feature = self.model(pred)
        # Target features are constants for the loss. Computing them without
        # building an autograd graph is equivalent to the former ``.detach()``
        # but avoids the wasted graph construction.
        with torch.no_grad():
            gt_feature = self.model(gt)
        feature_loss = self.criterion(pred_feature, gt_feature)
        return feature_loss * self.loss_weight
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.