
mmedit.models.editors.inst_colorization.inst_colorization

Module Contents

Classes

InstColorization

The InstColorization colorization method.

class mmedit.models.editors.inst_colorization.inst_colorization.InstColorization(data_preprocessor: Union[dict, mmengine.config.Config], image_model, instance_model, fusion_model, color_data_opt, which_direction='AtoB', loss=None, init_cfg=None, train_cfg=None, test_cfg=None)[source]

Bases: mmengine.model.BaseModel

The InstColorization colorization method.

This colorization method is implemented according to the paper:

Instance-aware Image Colorization, CVPR 2020

Adapted from ‘https://github.com/ericsujw/InstColorization.git’ ‘InstColorization/models/train_model’ Copyright (c) 2020, Su, under MIT License.

Parameters
  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.

  • image_model (dict) – Config for single image model

  • instance_model (dict) – Config for instance model

  • fusion_model (dict) – Config for fusion model

  • color_data_opt (dict) – Option for colorspace conversion

  • which_direction (str) – AtoB or BtoA

  • loss (dict) – Config for loss.

  • init_cfg (dict, optional) – Initialization config dict. Default: None.

  • train_cfg (dict) – Config for training. Default: None.

  • test_cfg (dict) – Config for testing. Default: None.
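
As a rough illustration of how the constructor arguments map onto a config, the sketch below assembles a model dict in the usual OpenMMLab style. Only the top-level keys come from the signature above. The `type='InstColorization'` registry key is an assumption, and every sub-config body is left as an empty placeholder because its fields depend on the networks and color options actually used.

```python
# Hypothetical config skeleton: only the top-level keys are taken from the
# constructor signature documented above.
model = dict(
    type='InstColorization',    # assumed registry name (the class name)
    data_preprocessor=dict(),   # placeholder: pre-process config
    image_model=dict(),         # placeholder: single image model config
    instance_model=dict(),      # placeholder: instance model config
    fusion_model=dict(),        # placeholder: fusion model config
    color_data_opt=dict(),      # placeholder: colorspace conversion options
    which_direction='AtoB',     # or 'BtoA'
    loss=None,
    init_cfg=None,
    train_cfg=None,
    test_cfg=None,
)

# Arguments without a default in the signature must always be provided.
REQUIRED = ('data_preprocessor', 'image_model', 'instance_model',
            'fusion_model', 'color_data_opt')
missing = [key for key in REQUIRED if key not in model]
```

A check like `missing` can catch an incomplete config before it is handed to a model builder.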

forward(inputs: torch.Tensor, data_samples: Optional[List[mmedit.structures.EditDataSample]] = None, mode: str = 'tensor', **kwargs)[source]

Returns losses or predictions of training, validation, testing, and simple inference process.

The forward method of BaseModel is abstract, so its subclasses must implement it.

Accepts inputs and data_samples processed by data_preprocessor, and returns results according to mode arguments.

During non-distributed training, validation, and testing, forward is called directly by BaseModel.train_step, BaseModel.val_step and BaseModel.test_step.

During distributed data parallel training process, MMSeparateDistributedDataParallel.train_step will first call DistributedDataParallel.forward to enable automatic gradient synchronization, and then call forward to get training loss.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

  • mode (str) –

    mode should be one of loss, predict and tensor. Default: ‘tensor’.

    • loss: Called by train_step; returns a loss dict used for logging.

    • predict: Called by val_step and test_step; returns a list of BaseDataElement results used for computing metrics.

    • tensor: Called in custom code to get Tensor-type results.

Returns

  • If mode == loss, return a dict of loss tensor used for backward and logging.

  • If mode == predict, return a list of BaseDataElement for computing metric and getting inference result.

  • If mode == tensor, return a tensor, a tuple of tensors, or a dict of tensors for custom use.

Return type

ForwardResults
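
The mode argument described above can be pictured as a simple dispatch onto the three mode-specific methods listed on this page. The sketch below is a toy stand-in written in plain Python, with lists instead of tensors and dicts instead of data samples. `ToyModel` and its arithmetic are hypothetical, and whether the real class dispatches in exactly this way is an assumption. In the real class, forward_train is marked abstract.

```python
from typing import Any, List, Optional


class ToyModel:
    """Stand-in that mirrors the mode dispatch; not the real class."""

    def forward(self, inputs: List[float],
                data_samples: Optional[List[Any]] = None,
                mode: str = 'tensor', **kwargs):
        # Route the call to the method that matches the requested mode.
        if mode == 'tensor':
            return self.forward_tensor(inputs, data_samples, **kwargs)
        if mode == 'predict':
            return self.forward_inference(inputs, data_samples, **kwargs)
        if mode == 'loss':
            return self.forward_train(inputs, data_samples, **kwargs)
        raise ValueError(
            f'Invalid mode "{mode}"; expected loss, predict or tensor.')

    def forward_tensor(self, inputs, data_samples=None, **kwargs):
        # Toy "network": double every value and return a dict of outputs.
        return {'output': [2 * x for x in inputs]}

    def forward_inference(self, inputs, data_samples=None, **kwargs):
        # Wrap each raw output in its own per-sample result.
        outputs = self.forward_tensor(inputs, data_samples)['output']
        return [{'pred': value} for value in outputs]

    def forward_train(self, inputs, data_samples=None, **kwargs):
        # Mean absolute error against targets carried in data_samples.
        outputs = self.forward_tensor(inputs, data_samples)['output']
        targets = data_samples or [0.0] * len(outputs)
        loss = sum(abs(o - t) for o, t in zip(outputs, targets)) / len(outputs)
        return {'loss': loss}


model = ToyModel()
tensor_out = model.forward([1.0, 2.0])                  # default mode
pred_out = model.forward([1.0, 2.0], mode='predict')
loss_out = model.forward([1.0, 2.0], data_samples=[2.0, 3.0], mode='loss')
```

Each mode returns a different container: a dict of raw outputs, a list with one result per sample, or a dict of losses.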

convert_to_datasample(inputs, data_samples)[source]

abstract forward_train(inputs, data_samples=None, **kwargs)[source]

Forward function for training.

abstract train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) Dict[str, torch.Tensor][source]

Train step function.

Parameters
  • data (List[dict]) – Batch of data as input.

  • optim_wrapper (mmengine.optim.OptimWrapperDict) – Dict of optimizer wrappers for the generator and, if present, the discriminator.

Returns

Dict with loss, information for logger, the number of samples and results for visualization.

Return type

dict

forward_inference(inputs, data_samples=None, **kwargs)[source]

Forward function for inference. Returns predictions for validation and testing.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

Predictions.

Return type

List[EditDataSample]

forward_tensor(inputs, data_samples)[source]

Forward function in tensor mode.

Parameters
  • inputs (torch.Tensor) – Input tensor.

  • data_samples (dict) – Dict containing the data samples.

Returns

Dict containing the output results.

Return type

dict
