Shortcuts

Source code for mmedit.engine.hooks.visualization_hook

# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
from collections import defaultdict
from copy import deepcopy
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch
from mmengine import MessageHub
from mmengine.dist import master_only
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.runner import Runner
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
from mmengine.visualization import Visualizer

from mmedit.structures import EditDataSample, PixelData
from mmedit.utils import get_sampler


@HOOKS.register_module()
[docs]class BasicVisualizationHook(Hook): """Basic hook that invoke visualizers during validation and test. Args: interval (int | dict): Visualization interval. Default: {}. on_train (bool): Whether to call hook during train. Default to False. on_val (bool): Whether to call hook during validation. Default to True. on_test (bool): Whether to call hook during test. Default to True. """
[docs] priority = 'NORMAL'
def __init__(self, interval: dict = {}, on_train=False, on_val=True, on_test=True): self._interval = interval self._sample_counter = 0 self._vis_dir = None self._on_train = on_train self._on_val = on_val self._on_test = on_test
[docs] def _after_iter( self, runner, batch_idx: int, data_batch: Optional[Sequence[dict]], outputs: Optional[Sequence[BaseDataElement]], mode=None, ) -> None: """Show or Write the predicted results. Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the test loop. data_batch (Sequence[dict], optional): Data from dataloader. Defaults to None. outputs (Sequence[BaseDataElement], optional): Outputs from model. Defaults to None. """ if mode == 'train' and (not self._on_train): return elif mode == 'val' and (not self._on_val): return elif mode == 'test' and (not self._on_test): return if isinstance(self._interval, int): interval = self._interval else: interval = self._interval.get(mode, 1) if self.every_n_inner_iters(batch_idx, interval): for data_sample in outputs: runner.visualizer.add_datasample(data_sample, step=runner.iter)
@HOOKS.register_module()
[docs]class GenVisualizationHook(Hook): """Generation Visualization Hook. Used to visual output samples in training, validation and testing. In this hook, we use a list called `sample_kwargs_list` to control how to generate samples and how to visualize them. Each element in `sample_kwargs_list`, called `sample_kwargs`, may contains the following keywords: - Required key words: - 'type': Value must be string. Denotes what kind of sampler is used to generate image. Refers to :meth:`~mmedit.utils.get_sampler`. - Optional key words (If not passed, will use the default value): - 'n_rows': Value must be int. The number of images in one row. - 'num_samples': Value must be int. The number of samples to visualize. - 'vis_mode': Value must be string. How to visualize the generated samples (e.g. image, gif). - 'fixed_input': Value must be bool. Whether use the fixed input during the loop. - 'draw_gt': Value must be bool. Whether save the real images. - 'target_keys': Value must be string or list of string. The keys of the target image to visualize. - 'name': Value must be string. If not passed, will use `sample_kwargs['type']` as default. For convenience, we also define a group of alias of samplers' type for models supported in MMEditing. Refers to `:attr:self.SAMPLER_TYPE_MAPPING`. Example: >>> # for GAN models >>> custom_hooks = [ >>> dict( >>> type='GenVisualizationHook', >>> interval=1000, >>> fixed_input=True, >>> vis_kwargs_list=dict(type='GAN', name='fake_img'))] >>> # for Translation models >>> custom_hooks = [ >>> dict( >>> type='GenVisualizationHook', >>> interval=10, >>> fixed_input=False, >>> vis_kwargs_list=[dict(type='Translation', >>> name='translation_train', >>> n_samples=6, draw_gt=True, >>> n_rows=3), >>> dict(type='TranslationVal', >>> name='translation_val', >>> n_samples=16, draw_gt=True, >>> n_rows=4)])] # NOTE: user-defined vis_kwargs > vis_kwargs_mapping > hook init args Args: interval (int): Visualization interval. Default: 1000. sampler_kwargs_list (Tuple[List[dict], dict]): The list of sampling behavior to generate images. fixed_input (bool): The default action of whether use fixed input to generate samples during the loop. Defaults to True. n_samples (Optional[int]): The default value of number of samples to visualize. Defaults to 64. n_row (Optional[int]): The default value of number of images in each row in the visualization results. Defaults to 8. message_hub_vis_kwargs (Optional[Tuple[str, dict, List[str], List[Dict]]]): Key arguments visualize images in message hub. Defaults to None. save_at_test (bool): Whether save images during test. Defaults to True. max_save_at_test (int): Maximum number of samples saved at test time. If None is passed, all samples will be saved. Defaults to 100. show (bool): Whether to display the drawn image. Default to False. wait_time (float): The interval of show (s). Defaults to 0. """
[docs] priority = 'NORMAL'
[docs] VIS_KWARGS_MAPPING = dict( GAN=dict(type='Noise'), SinGAN=dict(type='Arguments', forward_kwargs=dict(mode='rand')), Translation=dict(type='Data'), TranslationVal=dict(type='ValData'), TranslationTest=dict(type='TestData'), DDPMDenoising=dict( type='Arguments', name='ddpm_sample', n_samples=16, n_row=4, vis_mode='gif', n_skip=100, forward_kwargs=dict( forward_mode='sampling', sample_kwargs=dict(show_pbar=True, save_intermedia=True))))
def __init__(self, interval: int = 1000, vis_kwargs_list: Tuple[List[dict], dict] = None, fixed_input: bool = True, n_samples: Optional[int] = 64, n_row: Optional[int] = 8, message_hub_vis_kwargs: Optional[Tuple[str, dict, List[str], List[Dict]]] = None, save_at_test: bool = True, max_save_at_test: int = 100, test_vis_keys: Optional[Union[str, List[str]]] = None, show: bool = False, wait_time: float = 0): self._visualizer: Visualizer = Visualizer.get_current_instance() self.interval = interval self.vis_kwargs_list = deepcopy(vis_kwargs_list) if isinstance(self.vis_kwargs_list, dict): self.vis_kwargs_list = [self.vis_kwargs_list] self.fixed_input = fixed_input self.inputs_buffer = defaultdict(list) self.n_samples = n_samples self.n_row = n_row self.show = show if self.show: # No need to think about vis backends. self._visualizer._vis_backends = {} warnings.warn('The show is True, it means that only ' 'the prediction results are visualized ' 'without storing data, so vis_backends ' 'needs to be excluded.') self.wait_time = wait_time self.save_at_test = save_at_test self.test_vis_keys_list = test_vis_keys self.max_save_at_test = max_save_at_test self.message_vis_kwargs = message_hub_vis_kwargs @master_only
[docs] def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, outputs) -> None: """:class:`GenVisualizationHook` do not support visualize during validation. Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the test loop. data_batch (Sequence[dict], optional): Data from dataloader. Defaults to None. outputs: outputs of the generation model """ return
@master_only
[docs] def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, outputs): """Visualize samples after test iteraiton. Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the test loop. data_batch (dict, optional): Data from dataloader. Defaults to None. outputs: outputs of the generation model Defaults to None. """ if not self.save_at_test: return # get color order, mean and std module = runner.model if hasattr(module, 'module'): module = module.module data_preprocessor = module.data_preprocessor if hasattr(data_preprocessor, 'output_color_order'): output_color_order = data_preprocessor.output_color_order else: output_color_order = 'bgr' mean = data_preprocessor.mean std = data_preprocessor.std for idx, sample in enumerate(outputs): curr_idx = batch_idx * len(outputs) + idx if (self.max_save_at_test is not None and curr_idx >= self.max_save_at_test): continue if self.test_vis_keys_list is None: target_keys = [ k for k, v in sample.items() if not k.startswith('_') and isinstance(v, PixelData) ] assert len(target_keys), ( 'Cannot found PixelData in outputs. Please specific ' '\'vis_test_keys_list\'.') elif isinstance(self.test_vis_keys_list, str): target_keys = [self.test_vis_keys_list] else: assert is_list_of(self.test_vis_keys_list, str), ( 'test_vis_keys_list must be str or list of str or None.') target_keys = self.test_vis_keys_list for key in target_keys: name = key.replace('.', '_') self._visualizer.add_datasample( name=f'test_{name}', gen_samples=[sample], step=curr_idx, target_keys=key, n_row=1, color_order=output_color_order, target_mean=mean.cpu(), target_std=std.cpu())
@master_only
[docs] def after_train_iter(self, runner: Runner, batch_idx: int, data_batch: dict = None, outputs: Optional[dict] = None) -> None: """Visualize samples after train iteration. Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. data_batch (dict): Data from dataloader. Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ if self.every_n_inner_iters(batch_idx, self.interval): self.vis_sample(runner, batch_idx, data_batch, outputs)
@torch.no_grad()
[docs] def vis_sample(self, runner: Runner, batch_idx: int, data_batch: dict, outputs: Optional[dict] = None) -> None: """Visualize samples. Args: runner (Runner): The runner conatians model to visualize. batch_idx (int): The index of the current batch in loop. data_batch (dict): Data from dataloader. Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ # this function will only called in training process num_batches = runner.train_dataloader.batch_size module = runner.model module.eval() if hasattr(module, 'module'): module = module.module forward_func = module.val_step # get color order, mean and std data_preprocessor = module.data_preprocessor if hasattr(data_preprocessor, 'output_color_order'): output_color_order = data_preprocessor.output_color_order else: output_color_order = 'bgr' mean = data_preprocessor.mean std = data_preprocessor.std for vis_kwargs in self.vis_kwargs_list: # pop the sample-unrelated values vis_kwargs_ = deepcopy(vis_kwargs) sampler_type = vis_kwargs_['type'] # replace with alias for alias in self.VIS_KWARGS_MAPPING.keys(): if alias.upper() == sampler_type.upper(): sampler_alias = deepcopy(self.VIS_KWARGS_MAPPING[alias]) vis_kwargs_['type'] = sampler_alias.pop('type') for default_k, default_v in sampler_alias.items(): vis_kwargs_.setdefault(default_k, default_v) break # sampler_type = vis_kwargs_.pop('type') name = vis_kwargs_.pop('name', None) if not name: name = sampler_type.lower() n_samples = vis_kwargs_.pop('n_samples', self.n_samples) n_row = vis_kwargs_.pop('n_row', self.n_row) n_row = min(n_row, n_samples) num_iters = math.ceil(n_samples / num_batches) vis_kwargs_['max_times'] = num_iters vis_kwargs_['num_batches'] = num_batches fixed_input = vis_kwargs_.pop('fixed_input', self.fixed_input) target_keys = vis_kwargs_.pop('target_keys', None) vis_mode = vis_kwargs_.pop('vis_mode', None) output_list = [] if fixed_input and self.inputs_buffer[sampler_type]: sampler = self.inputs_buffer[sampler_type] else: sampler = get_sampler(vis_kwargs_, runner) need_save = fixed_input and not self.inputs_buffer[sampler_type] for inputs in sampler: output_list += [out for out in forward_func(inputs)] # save inputs if need_save: self.inputs_buffer[sampler_type].append(inputs) self._visualizer.add_datasample( name=name, gen_samples=output_list[:n_samples], target_keys=target_keys, vis_mode=vis_mode, n_row=n_row, color_order=output_color_order, target_mean=mean.cpu(), target_std=std.cpu(), show=self.show, wait_time=self.wait_time, step=batch_idx + 1, **vis_kwargs_) # save images in message_hub self.vis_from_message_hub(batch_idx, output_color_order, mean, std) module.train()
[docs] def vis_from_message_hub(self, batch_idx: int, color_order: str, target_mean: Sequence[Union[float, int]], target_std: Sequence[Union[float, int]]): """Visualize samples from message hub. Args: batch_idx (int): The index of the current batch in the test loop. color_order (str): The color order of generated images. target_mean (Sequence[Union[float, int]]): The original mean of the image tensor before preprocessing. Image will be re-shifted to ``target_mean`` before visualizing. target_std (Sequence[Union[float, int]]): The original std of the image tensor before preprocessing. Image will be re-scaled to ``target_std`` before visualizing. """ if self.message_vis_kwargs is None: return message_hub = MessageHub.get_current_instance() if 'vis_results' not in message_hub.runtime_info: raise RuntimeError('Cannot find \'vis_results\' in ' '\'message_hub.runtime_info\'. Cannot perform ' 'visualization from messageHub.') vis_results = message_hub.get_info('vis_results') if isinstance(self.message_vis_kwargs, str): target_keys, vis_modes = [self.message_vis_kwargs], [None] elif isinstance(self.message_vis_kwargs, dict): target_keys = [self.message_vis_kwargs['key']] vis_modes = [self.message_vis_kwargs['vis_mode']] elif is_list_of(self.message_vis_kwargs, str): target_keys = self.message_vis_kwargs vis_modes = [None for _ in range(len(target_keys))] else: # list of dict target_keys = [kwargs['key'] for kwargs in self.message_vis_kwargs] vis_modes = [ kwargs.pop('vis_mode', None) for kwargs in deepcopy(self.message_vis_kwargs) ] for key, vis_mode in zip(target_keys, vis_modes): if key not in vis_results: raise RuntimeError( f'Cannot find \'{key}\' in ' 'message_hub.runtime_info[\'vis_results\'].') value = vis_results[key] # pack to list of EditDataSample if isinstance(value, torch.Tensor): gen_samples = [] num_batches = value.shape[0] for idx in range(num_batches): gen_sample = EditDataSample() setattr(gen_sample, key, PixelData(data=value[idx])) gen_samples.append(gen_sample) elif is_list_of(value, BaseDataElement): # already packed gen_samples = value num_batches = len(gen_samples) else: raise TypeError( 'Only support to visualize Tensor or list of DataSample ' f'in MessageHub. But \'{key}\' is \'{type(value)}\'.') self._visualizer.add_datasample( name=f'train_{key}', gen_samples=gen_samples, target_keys=key, vis_mode=vis_mode, n_row=min(self.n_row, num_batches), color_order=color_order, target_mean=target_mean.cpu(), target_std=target_std.cpu(), show=self.show, step=batch_idx)
Read the Docs v: latest
Versions
master
latest
stable
zyh-re-docs
zyh-doc-notfound-extend
zyh-api-rendering
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.