
mmedit.models.base_models.base_translation_model

Module Contents

Classes

BaseTranslationModel

Base Translation Model.

class mmedit.models.base_models.base_translation_model.BaseTranslationModel(generator, discriminator, default_domain: str, reachable_domains: List[str], related_domains: List[str], data_preprocessor, discriminator_steps: int = 1, disc_init_steps: int = 0, real_img_key: str = 'real_img', loss_config: Optional[dict] = None)[source]

Bases: mmengine.model.BaseModel

Base Translation Model.

Translation models can transfer images from one domain to another. Domain information such as default_domain and reachable_domains is needed to initialize the class. The class also provides query functions such as is_domain_reachable and get_other_domains.

You can get the generator for a specific domain, and you can choose the domain of the generated images by passing target_domain to the forward function. Because image translation models differ considerably, only the external interfaces mentioned above are provided. To implement image translation with a specific method, inherit from both BaseTranslationModel and the method class (e.g. BaseGAN) and implement the abstract methods.
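
A minimal sketch of the inheritance pattern described above, using plain-Python stand-ins. ToyTranslationBase (in place of BaseTranslationModel), ToyGAN (in place of a method class such as BaseGAN) and ToyTranslator are hypothetical names, and the "generators" are plain functions rather than networks. Cooperative super().__init__(**kwargs) calls let each base consume its own constructor arguments.

```python
from typing import Callable, Dict, List


class ToyTranslationBase:
    """Hypothetical stand-in for the domain handling of BaseTranslationModel."""

    def __init__(self, default_domain: str, reachable_domains: List[str], **kwargs):
        super().__init__(**kwargs)  # pass remaining arguments along the MRO
        self._default_domain = default_domain
        self._reachable_domains = list(reachable_domains)

    def forward_test(self, img, target_domain, **kwargs):
        # Left abstract: the concrete translation method fills this in.
        raise NotImplementedError


class ToyGAN:
    """Hypothetical stand-in for a method class such as BaseGAN."""

    def __init__(self, generators: Dict[str, Callable], **kwargs):
        super().__init__(**kwargs)
        self.generators = generators


class ToyTranslator(ToyTranslationBase, ToyGAN):
    """Gets domain handling from one base and networks from the other."""

    def forward_test(self, img, target_domain, **kwargs):
        return {'target': self.generators[target_domain](img)}


model = ToyTranslator(
    default_domain='photo',
    reachable_domains=['photo'],
    generators={'photo': lambda x: [2 * v for v in x]},
)
print(model.forward_test([1, 2], 'photo'))  # {'target': [2, 4]}
```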

Parameters
  • default_domain (str) – Default output domain.

  • reachable_domains (list[str]) – Domains that can be generated by the model.

  • related_domains (list[str]) – Domains involved in training and testing. reachable_domains must be contained in related_domains. related_domains may additionally contain source domains that are not in reachable_domains but are used to retrieve source images from data_batch.

  • discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.

  • disc_init_steps (int) – The number of initial steps during which only the discriminators are trained. Defaults to 0.
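
One plausible reading of how discriminator_steps and disc_init_steps interact, sketched as a schedule. This is an assumption, not necessarily the exact rule the library applies: the discriminator is updated every iteration, and the generator is updated only after the first disc_init_steps iterations and then once every discriminator_steps iterations. The helper generator_update_steps is hypothetical.

```python
from typing import List


def generator_update_steps(total_steps: int, discriminator_steps: int = 1,
                           disc_init_steps: int = 0) -> List[int]:
    """Return the iterations at which the generator would be updated.

    Assumed rule (hypothetical): skip the generator during the warm-up of
    disc_init_steps iterations, then update it every discriminator_steps
    iterations. The discriminator is updated at every iteration.
    """
    return [
        step for step in range(total_steps)
        if step >= disc_init_steps and step % discriminator_steps == 0
    ]


print(generator_update_steps(8))  # [0, 1, 2, 3, 4, 5, 6, 7]
print(generator_update_steps(8, discriminator_steps=2, disc_init_steps=3))  # [4, 6]
```

With the defaults (1 and 0) the generator and discriminator are updated at every iteration.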

init_weights(pretrained=None)[source]

Initialize weights for the model.

Parameters

pretrained (str, optional) – Path to pretrained weights. If None, pretrained weights are not loaded. Defaults to None.

get_module(module)[source]

Get the nn.ModuleDict of networks, unwrapping it first if it is wrapped by MMDistributedDataParallel.

Parameters

module (MMDistributedDataParallel | nn.ModuleDict) – The input module that needs processing.

Returns

The ModuleDict of multiple networks.

Return type

nn.ModuleDict
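
A minimal sketch of the unwrapping idea. FakeDistributedWrapper is a hypothetical stand-in for MMDistributedDataParallel; like PyTorch's DistributedDataParallel, it keeps the wrapped network in its module attribute. A plain dict stands in for nn.ModuleDict so the example runs without PyTorch.

```python
class FakeDistributedWrapper:
    """Hypothetical stand-in for MMDistributedDataParallel.

    Like torch's DistributedDataParallel, it stores the wrapped network in
    its ``module`` attribute.
    """

    def __init__(self, module):
        self.module = module


def get_module(module):
    """Return the wrapped network if ``module`` is a distributed wrapper."""
    if isinstance(module, FakeDistributedWrapper):
        return module.module
    return module


networks = {'photo': 'generator_a'}  # stands in for an nn.ModuleDict
assert get_module(networks) is networks
assert get_module(FakeDistributedWrapper(networks)) is networks
```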

forward(img, test_mode=False, **kwargs)[source]

Forward function.

Parameters
  • img (tensor) – Input image tensor.

  • test_mode (bool) – Whether in test mode or not. Default: False.

  • kwargs (dict) – Other arguments.
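
A sketch of how forward might dispatch on test_mode, forwarding extra keyword arguments such as target_domain. ToyModel and the returned dict contents are hypothetical; the library's subclasses define what the results actually contain.

```python
class ToyModel:
    """Hypothetical model showing how forward() might dispatch on test_mode."""

    def forward(self, img, test_mode=False, **kwargs):
        if test_mode:
            return self.forward_test(img, **kwargs)
        return self.forward_train(img, **kwargs)

    def forward_train(self, img, target_domain, **kwargs):
        return {'mode': 'train', 'target_domain': target_domain}

    def forward_test(self, img, target_domain, **kwargs):
        return {'mode': 'test', 'target_domain': target_domain}


model = ToyModel()
print(model.forward([0.5], target_domain='photo'))
# {'mode': 'train', 'target_domain': 'photo'}
print(model.forward([0.5], test_mode=True, target_domain='photo'))
# {'mode': 'test', 'target_domain': 'photo'}
```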

forward_train(img, target_domain, **kwargs)[source]

Forward function for training.

Parameters
  • img (tensor) – Input image tensor.

  • target_domain (str) – Target domain of output image.

  • kwargs (dict) – Other arguments.

Returns

Forward results.

Return type

dict

forward_test(img, target_domain, **kwargs)[source]

Forward function for testing.

Parameters
  • img (tensor) – Input image tensor.

  • target_domain (str) – Target domain of output image.

  • kwargs (dict) – Other arguments.

Returns

Forward results.

Return type

dict

is_domain_reachable(domain)[source]

Whether images of this domain can be generated by the model.

get_other_domains(domain)[source]

Get the related domains other than the given domain.
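
The two query functions can be sketched as free functions over the domain lists, assuming "other domains" means the related domains other than the given one. The function signatures here are hypothetical (the methods above take only domain), and keeping the original order is a choice of this sketch.

```python
from typing import List


def is_domain_reachable(domain: str, reachable_domains: List[str]) -> bool:
    """True if the model can generate images of ``domain``."""
    return domain in reachable_domains


def get_other_domains(domain: str, related_domains: List[str]) -> List[str]:
    """Related domains other than ``domain``, in their original order."""
    return [d for d in related_domains if d != domain]


related = ['photo', 'label', 'mask']  # 'label' and 'mask' serve only as sources
reachable = ['photo']

print(is_domain_reachable('photo', reachable))  # True
print(is_domain_reachable('mask', reachable))   # False
print(get_other_domains('photo', related))      # ['label', 'mask']
```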

_get_target_generator(domain)[source]

Get the generator for the target domain.

_get_target_discriminator(domain)[source]

Get the discriminator for the target domain.
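
A sketch of a per-domain network lookup that rejects unreachable domains, as the two private helpers might do. get_target_network is a hypothetical helper, and raising ValueError is an assumption of this sketch; plain functions stand in for generator networks.

```python
from typing import Callable, Dict, List


def get_target_network(domain: str, networks: Dict[str, Callable],
                       reachable_domains: List[str]) -> Callable:
    """Look up the network registered for ``domain`` (hypothetical helper)."""
    if domain not in reachable_domains:
        raise ValueError(f'{domain!r} is not a reachable domain')
    return networks[domain]


generators = {'photo': lambda x: [v + 1 for v in x]}
gen = get_target_network('photo', generators, reachable_domains=['photo'])
print(gen([1, 2]))  # [2, 3]

try:
    get_target_network('mask', generators, reachable_domains=['photo'])
except ValueError as err:
    print(err)  # 'mask' is not a reachable domain
```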

translation(image, target_domain=None, **kwargs)[source]

Translate an image to the target style.

Parameters
  • image (tensor) – Image tensor with a shape of (N, C, H, W).

  • target_domain (str, optional) – Target domain of output image. If None, default_domain is used. Defaults to None.

Returns

A dict containing the image tensor of the target style.

Return type

dict
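
A minimal sketch of translation falling back to the default output domain when target_domain is None and returning a results dict. ToyDefaultTranslator and the 'target' key are hypothetical; plain functions stand in for the generators.

```python
from typing import Callable, Dict, Optional


class ToyDefaultTranslator:
    """Hypothetical model showing translation() with a default target domain."""

    def __init__(self, default_domain: str, generators: Dict[str, Callable]):
        self._default_domain = default_domain
        self.generators = generators

    def forward_test(self, img, target_domain, **kwargs):
        return {'target': self.generators[target_domain](img)}

    def translation(self, image, target_domain: Optional[str] = None, **kwargs):
        # Fall back to the default output domain when none is given.
        if target_domain is None:
            target_domain = self._default_domain
        return self.forward_test(image, target_domain=target_domain, **kwargs)


model = ToyDefaultTranslator('photo', {'photo': lambda x: [-v for v in x],
                                       'sketch': lambda x: [0 for _ in x]})
print(model.translation([1, 2]))                          # {'target': [-1, -2]}
print(model.translation([1, 2], target_domain='sketch'))  # {'target': [0, 0]}
```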
