Shortcuts

Source code for mmedit.models.base_models.base_edit_model

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch
from mmengine.model import BaseModel

from mmedit.registry import MODELS
from mmedit.structures import EditDataSample, PixelData


@MODELS.register_module()
class BaseEditModel(BaseModel):
    """Base model for image and video editing.

    It must contain a generator that takes frames as inputs and outputs an
    interpolated frame. It also has a pixel-wise loss for training.

    Args:
        generator (dict): Config for the generator structure.
        pixel_loss (dict): Config for pixel-wise loss.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        init_cfg (dict, optional): The weight initialized config for
            :class:`BaseModule`.
        data_preprocessor (dict, optional): The pre-process config of
            :class:`BaseDataPreprocessor`.

    Attributes:
        init_cfg (dict, optional): Initialization config dict.
        data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
            pre-processing data sampled by dataloader to the format accepted
            by :meth:`forward`. Default: None.
    """

    def __init__(self,
                 generator: dict,
                 pixel_loss: dict,
                 train_cfg: Optional[dict] = None,
                 test_cfg: Optional[dict] = None,
                 init_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None):
        # Let BaseModel wire up weight init and the data preprocessor first.
        super().__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        # Build the submodules from their registry configs.
        self.generator = MODELS.build(generator)
        self.pixel_loss = MODELS.build(pixel_loss)

        # Stash the (possibly None) train/test configs for subclasses.
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
[文档] def forward(self, inputs: torch.Tensor, data_samples: Optional[List[EditDataSample]] = None, mode: str = 'tensor', **kwargs) -> Union[torch.Tensor, List[EditDataSample], dict]: """Returns losses or predictions of training, validation, testing, and simple inference process. ``forward`` method of BaseModel is an abstract method, its subclasses must implement this method. Accepts ``inputs`` and ``data_samples`` processed by :attr:`data_preprocessor`, and returns results according to mode arguments. During non-distributed training, validation, and testing process, ``forward`` will be called by ``BaseModel.train_step``, ``BaseModel.val_step`` and ``BaseModel.val_step`` directly. During distributed data parallel training process, ``MMSeparateDistributedDataParallel.train_step`` will first call ``DistributedDataParallel.forward`` to enable automatic gradient synchronization, and then call ``forward`` to get training loss. Args: inputs (torch.Tensor): batch input tensor collated by :attr:`data_preprocessor`. data_samples (List[BaseDataElement], optional): data samples collated by :attr:`data_preprocessor`. mode (str): mode should be one of ``loss``, ``predict`` and ``tensor``. Default: 'tensor'. - ``loss``: Called by ``train_step`` and return loss ``dict`` used for logging - ``predict``: Called by ``val_step`` and ``test_step`` and return list of ``BaseDataElement`` results used for computing metric. - ``tensor``: Called by custom use to get ``Tensor`` type results. Returns: ForwardResults: - If ``mode == loss``, return a ``dict`` of loss tensor used for backward and logging. - If ``mode == predict``, return a ``list`` of :obj:`BaseDataElement` for computing metric and getting inference result. - If ``mode == tensor``, return a tensor or ``tuple`` of tensor or ``dict`` or tensor for custom use. 
""" if mode == 'tensor': return self.forward_tensor(inputs, data_samples, **kwargs) elif mode == 'predict': predictions = self.forward_inference(inputs, data_samples, **kwargs) predictions = self.convert_to_datasample(data_samples, predictions) return predictions elif mode == 'loss': return self.forward_train(inputs, data_samples, **kwargs)
[文档] def convert_to_datasample(self, inputs: List[EditDataSample], data_samples: List[EditDataSample] ) -> List[EditDataSample]: for data_sample, output in zip(inputs, data_samples): data_sample.output = output return inputs
[文档] def forward_tensor(self, inputs: torch.Tensor, data_samples: Optional[List[EditDataSample]] = None, **kwargs) -> torch.Tensor: """Forward tensor. Returns result of simple forward. Args: inputs (torch.Tensor): batch input tensor collated by :attr:`data_preprocessor`. data_samples (List[BaseDataElement], optional): data samples collated by :attr:`data_preprocessor`. Returns: Tensor: result of simple forward. """ feats = self.generator(inputs, **kwargs) return feats
[文档] def forward_inference(self, inputs: torch.Tensor, data_samples: Optional[List[EditDataSample]] = None, **kwargs) -> List[EditDataSample]: """Forward inference. Returns predictions of validation, testing, and simple inference. Args: inputs (torch.Tensor): batch input tensor collated by :attr:`data_preprocessor`. data_samples (List[BaseDataElement], optional): data samples collated by :attr:`data_preprocessor`. Returns: List[EditDataSample]: predictions. """ feats = self.forward_tensor(inputs, data_samples, **kwargs) feats = self.data_preprocessor.destructor(feats) predictions = [] for idx in range(feats.shape[0]): predictions.append( EditDataSample( pred_img=PixelData(data=feats[idx].to('cpu')), metainfo=data_samples[idx].metainfo)) return predictions
[文档] def forward_train(self, inputs: torch.Tensor, data_samples: Optional[List[EditDataSample]] = None, **kwargs) -> Dict[str, torch.Tensor]: """Forward training. Returns dict of losses of training. Args: inputs (torch.Tensor): batch input tensor collated by :attr:`data_preprocessor`. data_samples (List[BaseDataElement], optional): data samples collated by :attr:`data_preprocessor`. Returns: dict: Dict of losses. """ feats = self.forward_tensor(inputs, data_samples, **kwargs) gt_imgs = [data_sample.gt_img.data for data_sample in data_samples] batch_gt_data = torch.stack(gt_imgs) loss = self.pixel_loss(feats, batch_gt_data) return dict(loss=loss)
Read the Docs v: latest
Versions
master
latest
stable
zyh-doc-notfound-extend
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.