Shortcuts

Source code for mmedit.models.editors.arcface.id_loss

# Copyright (c) OpenMMLab. All rights reserved.
import mmengine
import torch
from torch import nn

from mmedit.registry import MODELS
from .model_irse import Backbone


@MODELS.register_module('ArcFace')
[docs]class IDLossModel(nn.Module): """Face id loss model. Args: ir_se50_weights (str, optional): Url of ir-se50 weights. Defaults to None. """ # ir se50 weight download link
[docs] _ir_se50_url = 'https://gg0ltg.by.files.1drv.com/y4m3fNNszG03z9n8JQ7EhdtQKW8tQVQMFBisPVRgoXi_UfP8pKSSqv8RJNmHy2JampcPmEazo_Mx6NTFSqBpZmhPniROm9uNoghnzaavvYpxkCfiNmDH9YyIF3g-0nwt6bsjk2X80JDdL5z88OAblSDmB-kuQkWSWvA9BM3Xt8DHMCY8lO4HOQCZ5YWUtFyPAVwEyzTGDM-JRA5EJoN2bF1cg' # noqa
def __init__(self, ir_se50_weights=None): super(IDLossModel, self).__init__() mmengine.print_log('Loading ResNet ArcFace', 'current') self.facenet = Backbone( input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') if ir_se50_weights is None: ir_se50_weights = self._ir_se50_url self.facenet.load_state_dict( torch.hub.load_state_dict_from_url( ir_se50_weights, map_location='cpu')) self.pool = torch.nn.AdaptiveAvgPool2d((256, 256)) self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) self.facenet = self.facenet.eval()
[docs] def extract_feats(self, x): """Extracting face features. Args: x (torch.Tensor): Image tensor of faces. Returns: torch.Tensor: Face features. """ if x.shape[2] != 256: x = self.pool(x) x = x[:, :, 35:223, 32:220] # Crop interesting region x = self.face_pool(x) x_feats = self.facenet(x) return x_feats
[docs] def forward(self, pred=None, gt=None): """Calculate face loss. Args: pred (torch.Tensor, optional): Predictions of face images. Defaults to None. gt (torch.Tensor, optional): Ground truth of face images. Defaults to None. Returns: Tuple(float, float): A tuple contain face similarity loss and improvement. """ n_samples = gt.shape[0] y_feats = self.extract_feats( gt) # Otherwise use the feature from there y_hat_feats = self.extract_feats(pred) y_feats = y_feats.detach() loss = 0 sim_improvement = 0 count = 0 for i in range(n_samples): diff_target = y_hat_feats[i].dot(y_feats[i]) loss += 1 - diff_target count += 1 return loss / count, sim_improvement / count
Read the Docs v: latest
Versions
master
latest
stable
zyh-re-docs
zyh-doc-notfound-extend
zyh-api-rendering
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.