Shortcuts

Source code for mmedit.datasets.basic_image_dataset

# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import re
from typing import Callable, List, Optional, Tuple, Union

from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend, list_from_file

from mmedit.registry import DATASETS

# File suffixes recognised as images. Both lower- and upper-case spellings are
# listed explicitly; the tuple is used as the default ``img_suffix`` filter of
# ``BasicImageDataset`` when a folder is scanned for images.
IMG_EXTENSIONS = (
    '.jpg', '.JPG',
    '.jpeg', '.JPEG',
    '.png', '.PNG',
    '.ppm', '.PPM',
    '.bmp', '.BMP',
    '.tif', '.TIF',
    '.tiff', '.TIFF',
)
@DATASETS.register_module()
class BasicImageDataset(BaseDataset):
    """BasicImageDataset for open source projects in OpenMMLab/MMEditing.

    This dataset is designed for low-level vision tasks with image,
    such as super-resolution and inpainting.

    The annotation file is optional.

    If use annotation file, the annotation format can be shown as follows.

    .. code-block:: none

        Case 1 (CelebA-HQ):

            000001.png
            000002.png

        Case 2 (DIV2K):

            0001_s001.png (480,480,3)
            0001_s002.png (480,480,3)
            0001_s003.png (480,480,3)
            0002_s001.png (480,480,3)
            0002_s002.png (480,480,3)

        Case 3 (Vimeo90k):

            00001/0266 (256, 448, 3)
            00001/0268 (256, 448, 3)

    Only the first space-separated field of each annotation line (the path)
    is used by this class; the shape information is ignored.

    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to None.
        data_prefix (dict): Prefix (folder) for each kind of data, keyed by
            data name. Defaults to dict(img='').
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        filename_tmpl (dict): Template for each filename, keyed like
            ``data_prefix``. Note that the template excludes the file
            extension. Keys of ``data_prefix`` that are missing here fall
            back to ``'{}'``. The passed dict is copied, never modified.
            Default: dict().
        search_key (str, optional): The key of ``data_prefix`` whose folder
            is scanned to get data_list when no annotation file is used.
            Defaults to None, which selects the first key of
            ``data_prefix``.
        backend_args (dict, optional): Arguments to instantiate the
            prefix of uri corresponding backend. Defaults to None.
        img_suffix (str or tuple[str], optional): File suffix that we are
            interested in when scanning a folder.
            Defaults to ``IMG_EXTENSIONS``.
        recursive (bool): If set to True, recursively scan the
            directory. Default: False.
        **kwards: Extra keyword arguments forwarded unchanged to
            ``BaseDataset.__init__``.

    Note:

        Assume the file structure as the following:

        .. code-block:: none

            mmediting (root)
            ├── mmedit
            ├── tools
            ├── configs
            ├── data
            │   ├── DIV2K
            │   │   ├── DIV2K_train_HR
            │   │   │   ├── image.png
            │   │   ├── DIV2K_train_LR_bicubic
            │   │   │   ├── X2
            │   │   │   ├── X3
            │   │   │   ├── X4
            │   │   │   │   ├── image_x4.png
            │   │   ├── DIV2K_valid_HR
            │   │   ├── DIV2K_valid_LR_bicubic
            │   │   │   ├── X2
            │   │   │   ├── X3
            │   │   │   ├── X4
            │   ├── places
            │   │   ├── test_set
            │   │   ├── train_set
            |   |   ├── meta
            |   |   |    ├── Places365_train.txt
            |   |   |    ├── Places365_val.txt

    Examples:

        Case 1: Loading DIV2K dataset for training a SISR model.

        .. code-block:: python

            dataset = BasicImageDataset(
                ann_file='',
                metainfo=dict(
                    dataset_type='div2k',
                    task_name='sisr'),
                data_root='data/DIV2K',
                data_prefix=dict(
                    gt='DIV2K_train_HR', img='DIV2K_train_LR_bicubic/X4'),
                filename_tmpl=dict(img='{}_x4', gt='{}'),
                pipeline=[])

        Case 2: Loading places dataset for training an inpainting model.

        .. code-block:: python

            dataset = BasicImageDataset(
                ann_file='meta/Places365_train.txt',
                metainfo=dict(
                    dataset_type='places365',
                    task_name='inpainting'),
                data_root='data/places',
                data_prefix=dict(gt='train_set'),
                pipeline=[])
    """

    METAINFO = dict(dataset_type='basic_image_dataset', task_name='editing')

    def __init__(self,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: dict = dict(img=''),
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 filename_tmpl: dict = dict(),
                 search_key: Optional[str] = None,
                 backend_args: Optional[dict] = None,
                 img_suffix: Optional[Union[str, Tuple[str]]] = IMG_EXTENSIONS,
                 recursive: bool = False,
                 **kwards):
        # The mutable defaults in the signature are kept for interface
        # compatibility; nothing in this method mutates them.
        # Work on a private copy of ``filename_tmpl``: filling in the '{}'
        # fallbacks in place would modify the caller's dict and, worse, the
        # shared default dict, leaking keys from one instance to the next.
        filename_tmpl = dict(filename_tmpl)
        for key in data_prefix:
            filename_tmpl.setdefault(key, '{}')

        if search_key is None:
            # Fall back to the first declared prefix.
            search_key = list(data_prefix.keys())[0]

        self.search_key = search_key
        self.filename_tmpl = filename_tmpl
        self.use_ann_file = (ann_file != '')
        self.backend_args = (
            None if backend_args is None else backend_args.copy())
        self.img_suffix = img_suffix
        self.recursive = recursive
        self.file_backend = get_file_backend(
            uri=data_root, backend_args=backend_args)

        # NOTE(review): every attribute read by ``load_data_list`` is set
        # before this call, presumably because the base class may load the
        # data list during construction -- keep this ordering.
        super().__init__(
            ann_file=ann_file,
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=data_prefix,
            pipeline=pipeline,
            test_mode=test_mode,
            **kwards)

    def load_data_list(self) -> List[dict]:
        """Load data list from folder or annotation file.

        Every path returned by :meth:`_get_path_list` is split into a
        base name and an extension. The base name becomes the sample
        ``key``; for each entry of ``self.data_prefix`` the matching
        filename template is applied to the base name, the extension is
        appended and the result is joined onto that prefix.

        Returns:
            list[dict]: A list of annotation. Each dict holds ``key`` plus
            one ``'<name>_path'`` item per key of ``self.data_prefix``.
        """
        data_list = []
        for file in self._get_path_list():
            basename, ext = osp.splitext(file)
            if basename.startswith(os.sep):
                # Avoid absolute-path-like annotations: ``osp.join`` would
                # otherwise discard the prefix.
                basename = basename[1:]

            data = dict(key=basename)
            for key, prefix in self.data_prefix.items():
                filename = self.filename_tmpl[key].format(basename) + ext
                data[f'{key}_path'] = osp.join(prefix, filename)
            data_list.append(data)

        return data_list

    def _get_path_list(self) -> List[str]:
        """Get list of paths from annotation file or folder of dataset.

        The annotation file is used when one was given at construction
        time (``ann_file != ''``); otherwise the ``search_key`` folder is
        scanned.

        Returns:
            list[str]: A list of paths.
        """
        if self.use_ann_file:
            return self._get_path_list_from_ann()
        return self._get_path_list_from_folder()

    def _get_path_list_from_ann(self) -> List[str]:
        """Get list of paths from annotation file.

        The path is the first space-separated field of each line. Blank
        lines are skipped.

        Returns:
            list[str]: List of paths, using ``os.sep`` as separator.
        """
        ann_list = list_from_file(
            self.ann_file, backend_args=self.backend_args)

        path_list = []
        for ann in ann_list:
            # Strip first so that a line with leading whitespace does not
            # produce an empty path as its first field.
            ann = ann.strip()
            if not ann:
                continue
            path = ann.split(' ', 1)[0]
            # Compatible with Windows file systems
            path_list.append(path.replace('/', os.sep))

        return path_list

    def _get_path_list_from_folder(self) -> List[str]:
        """Get list of paths from folder.

        Lists the files under ``self.data_prefix[self.search_key]`` that
        carry one of ``self.img_suffix`` and whose name (without
        extension) fits the filename template of ``search_key``. The
        template text is removed from each name, so the returned paths
        are the template-free names that :meth:`load_data_list` formats
        again for every data prefix.

        Returns:
            list[str]: List of paths.
        """
        template = self.filename_tmpl[self.search_key]
        # Split the template into the literal text before and after the
        # placeholder by formatting it with a marker that cannot occur in a
        # file name.
        marker = '\0'
        before, _, after = template.format(marker).partition(marker)
        # The literal parts are escaped and the pattern must cover the whole
        # name: treating the raw template as a regex prefix would accept
        # names that merely contain the template text (e.g. 'a_x4_b' for
        # '{}_x4') and misread characters such as '.' or '+'.
        name_pattern = re.compile(
            re.escape(before) + '(.*)' + re.escape(after))

        folder = self.data_prefix[self.search_key]
        path_list = []
        # NOTE(review): the yielded paths are presumably relative to
        # ``folder`` -- confirm against the file backend documentation.
        for img_path in self.file_backend.list_dir_or_file(
                dir_path=folder,
                list_dir=False,
                suffix=self.img_suffix,
                recursive=self.recursive,
        ):
            basename, ext = osp.splitext(img_path)
            matched = name_pattern.fullmatch(basename)
            if matched:
                # Group 1 is the name with the template text removed.
                path_list.append(matched.group(1) + ext)

        return path_list
Read the Docs v: latest
Versions
master
latest
stable
zyh-re-docs
zyh-doc-notfound-extend
zyh-api-rendering
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.