
mmedit.models.editors.pix2pix

Package Contents

Classes

Pix2Pix

Pix2Pix model for paired image-to-image translation.

UnetGenerator

Construct the Unet-based generator from the innermost layer to the outermost layer.

class mmedit.models.editors.pix2pix.Pix2Pix(*args, **kwargs)[source]

Bases: mmedit.models.base_models.BaseTranslationModel

Pix2Pix model for paired image-to-image translation.

Ref:

Image-to-Image Translation with Conditional Adversarial Networks
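For orientation, the referenced paper trains the generator G against a conditional discriminator D that sees the input x together with either the real target y or the generated output G(x). It also adds an L1 reconstruction term weighted by λ. The paper's noise input z is omitted here, and the paper's experiments use λ = 100. This restates the paper's objective and is not a transcription of this class's loss configuration.

```latex
\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\left[\log D(x, y)\right]
  + \mathbb{E}_{x}\left[\log\left(1 - D(x, G(x))\right)\right]

\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y}\left[\lVert y - G(x) \rVert_1\right]

G^{*} = \arg\min_{G} \max_{D} \; \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathcal{L}_{L1}(G)
```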

forward_test(img, target_domain, **kwargs)

Forward function for testing.

Parameters
  • img (Tensor) – Input image tensor.

  • target_domain (str) – Target domain of the output image.

  • kwargs (dict) – Other arguments.

Returns

Forward results.

Return type

dict

_get_disc_loss(outputs)

Get the loss of the discriminator.

Parameters

outputs (dict) – A dict of outputs.

Returns

The loss and a dict of logged loss terms.

Return type

Tuple
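The discriminator loss can be sketched in plain Python. The sketch assumes the discriminator has already produced raw logits for (input, real target) pairs and (input, generated output) pairs. It applies binary cross-entropy with targets 1 and 0 and averages the two terms with the 0.5 factor used in the original pix2pix code. The function names and log keys are hypothetical and do not come from mmedit. Like `_get_disc_loss`, the sketch returns a `(loss, log_vars)` tuple.

```python
import math
from typing import Dict, List, Tuple


def bce_with_logits(logits: List[float], target: float) -> float:
    """Mean binary cross-entropy of raw logits against a constant target."""
    total = 0.0
    for z in logits:
        # Numerically stable form: max(z, 0) - z * t + log(1 + exp(-|z|)).
        total += max(z, 0.0) - z * target + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


def disc_loss_sketch(
    real_logits: List[float], fake_logits: List[float]
) -> Tuple[float, Dict[str, float]]:
    """Hypothetical helper: conditional-GAN discriminator loss and its log terms."""
    # D should score generated pairs as fake (0) and real pairs as real (1).
    loss_fake = bce_with_logits(fake_logits, 0.0)
    loss_real = bce_with_logits(real_logits, 1.0)
    # Halve the sum so D learns at a slower rate than G.
    loss = 0.5 * (loss_fake + loss_real)
    log_vars = {"loss_d_fake": loss_fake, "loss_d_real": loss_real, "loss": loss}
    return loss, log_vars
```

A discriminator that outputs logit 0 everywhere is maximally uncertain and scores log 2 ≈ 0.693. Confident, correct logits drive the loss toward zero.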

_get_gen_loss(outputs)

Get the loss of the generator.

Parameters

outputs (dict) – A dict of outputs.

Returns

The loss and a dict of logged loss terms.

Return type

Tuple
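The generator loss can be sketched the same way. It combines an adversarial term, which rewards fooling D into scoring generated pairs as real, with a λ-weighted L1 reconstruction term. The default `lambda_l1=100.0` mirrors the referenced paper's experiments and is an assumption here, not a documented default of this class. Images are flattened to lists of floats, and the helper name and log keys are hypothetical.

```python
import math
from typing import Dict, List, Tuple


def gen_loss_sketch(
    fake_logits: List[float],
    fake: List[float],
    target: List[float],
    lambda_l1: float = 100.0,  # assumed weight, taken from the paper's experiments
) -> Tuple[float, Dict[str, float]]:
    """Hypothetical helper: pix2pix-style generator loss and its log terms."""
    # Adversarial term: BCE with target 1 on D's logits for generated pairs.
    loss_gan = sum(
        max(z, 0.0) - z + math.log1p(math.exp(-abs(z))) for z in fake_logits
    ) / len(fake_logits)
    # Reconstruction term: mean absolute error between output and ground truth.
    loss_l1 = sum(abs(a - b) for a, b in zip(fake, target)) / len(fake)
    loss = loss_gan + lambda_l1 * loss_l1
    log_vars = {"loss_gan_g": loss_gan, "loss_l1": loss_l1, "loss": loss}
    return loss, log_vars
```

The large λ makes pixel-level fidelity dominate early training, while the adversarial term sharpens details that an L1 loss alone tends to blur.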

train_step(data, optim_wrapper=None)

Training step function.

Parameters
  • data (dict) – Dict of the input data batch.

  • optimizer (dict[torch.optim.Optimizer]) – Dict of optimizers for the generator and discriminator.

  • ddp_reducer (Reducer | None, optional) – Reducer from DDP, used to prepare for backward() under DDP. Defaults to None.

  • running_status (dict | None, optional) – Basic information needed for training, e.g., the iteration number. Defaults to None.

Returns

Dict of losses, information for the logger, the number of samples, and results for visualization.

Return type

dict
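A GAN training step alternates between the two networks. The following minimal sketch covers one iteration and rests on these assumptions: `update_disc` and `update_gen` are hypothetical callables that each run one optimizer step and return a scalar loss. The gating by `disc_steps` and `disc_init_steps` is likewise hypothetical and illustrates how generator updates can be made less frequent than discriminator updates. It is not a description of this method's internals.

```python
from typing import Callable, Dict


def toy_train_step(
    update_disc: Callable[[], float],
    update_gen: Callable[[], float],
    curr_iter: int,
    disc_steps: int = 1,  # hypothetical: update G once every `disc_steps` iterations
    disc_init_steps: int = 0,  # hypothetical: D-only warm-up iterations
) -> Dict[str, float]:
    """Hypothetical sketch of one alternating GAN iteration (D first, then G)."""
    # Always take one discriminator step first.
    log_vars = {"loss_disc": update_disc()}
    # Take a generator step only on scheduled iterations after warm-up.
    if curr_iter % disc_steps == 0 and curr_iter >= disc_init_steps:
        log_vars["loss_gen"] = update_gen()
    return log_vars
```

Updating D before G means the generator step is taken against the freshly updated discriminator.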

test_step(data: dict) → mmedit.utils.typing.SampleList

Gets the generated image of the given data. Same as val_step().

Parameters

data (dict) – Data sampled from the metric-specific sampler. More details in Metrics and Evaluator.

Returns

Generated image or image dict.

Return type

List[EditDataSample]

val_step(data: dict) → mmedit.utils.typing.SampleList

Gets the generated image of the given data. Same as test_step().

Parameters

data (dict) – Data sampled from the metric-specific sampler. More details in Metrics and Evaluator.

Returns

Generated image or image dict.

Return type

List[EditDataSample]

class mmedit.models.editors.pix2pix.UnetGenerator(in_channels, out_channels, num_down=8, base_channels=64, norm_cfg=dict(type='BN'), use_dropout=False, init_cfg=dict(type='normal', gain=0.02))[source]

Bases: torch.nn.Module

Construct the Unet-based generator from the innermost layer to the outermost layer, which is a recursive process.
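The innermost-to-outermost construction can be illustrated by listing each level's channel widths in build order. This hypothetical helper assumes the channel schedule of the original pix2pix Unet: the innermost block and `num_down - 5` further blocks use 8 × `base_channels`, then widths halve outward to `base_channels`, and the outermost block maps back to `out_channels`. It also assumes `num_down >= 5`. Each tuple is `(outer_channels, inner_channels)` for one skip-connection block, and each block contributes one downsampling.

```python
from typing import List, Tuple


def unet_level_channels(
    out_channels: int, num_down: int = 8, base_channels: int = 64
) -> List[Tuple[int, int]]:
    """Hypothetical helper: (outer, inner) channels per level, innermost first."""
    if num_down < 5:
        raise ValueError("this sketch assumes num_down >= 5")
    b = base_channels
    levels = [(8 * b, 8 * b)]  # innermost block (bottleneck)
    levels += [(8 * b, 8 * b)] * (num_down - 5)  # extra widest blocks
    levels += [(4 * b, 8 * b), (2 * b, 4 * b), (b, 2 * b)]  # narrow outward
    levels.append((out_channels, b))  # outermost block maps back to output
    return levels
```

With the defaults and a 3-channel output, the list runs from four `(512, 512)` blocks out to `(3, 64)`, eight levels in all. Each level wraps the one built before it, which is why the construction proceeds recursively from the inside out.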

Parameters
  • in_channels (int) – Number of channels in input images.

  • out_channels (int) – Number of channels in output images.

  • num_down (int) – Number of downsamplings in the Unet. If num_down is 8, a 256x256 image becomes 1x1 at the bottleneck. Default: 8.

  • base_channels (int) – Number of channels at the last conv layer. Default: 64.

  • norm_cfg (dict) – Config dict to build the norm layer. Default: dict(type='BN').

  • use_dropout (bool) – Whether to use dropout layers. Default: False.

  • init_cfg (dict) – Config dict for initialization. type: name of the initialization method. Default: 'normal'. gain: scaling factor for normal, xavier, and orthogonal initialization. Default: 0.02.
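The num_down example follows from simple arithmetic. Each downsampling halves the spatial size, so the bottleneck size is size / 2^num_down, and 256 / 2^8 = 1. This hypothetical helper also assumes that the input size should be divisible by 2^num_down. With that constraint, every level halves exactly and the upsampled features line up with their skip connections.

```python
def bottleneck_size(size: int, num_down: int = 8) -> int:
    """Hypothetical helper: spatial size after num_down stride-2 downsamplings."""
    factor = 2 ** num_down
    # Assumed constraint so every encoder level halves the size exactly.
    if size % factor != 0:
        raise ValueError(f"size {size} is not divisible by 2**{num_down} = {factor}")
    return size // factor
```

For example, `bottleneck_size(256)` gives 1 and `bottleneck_size(256, num_down=7)` gives 2.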

forward(x)

Forward function.

Parameters

x (Tensor) – Input tensor with shape (n, c, h, w).

Returns

Forward results.

Return type

Tensor

init_weights(pretrained=None, strict=True)

Initialize weights for the model.

Parameters
  • pretrained (str, optional) – Path to pretrained weights. If None, no pretrained weights are loaded. Default: None.

  • strict (bool, optional) – Whether to strictly enforce that the parameter keys in the checkpoint match those of the model. Default: True.
