
mmedit.models.editors.sagan.sagan_discriminator

Module Contents

Classes

ProjDiscriminator

Discriminator for SNGAN / Proj-GAN.

class mmedit.models.editors.sagan.sagan_discriminator.ProjDiscriminator(input_scale, num_classes=0, base_channels=128, input_channels=3, attention_cfg=dict(type='SelfAttentionBlock'), attention_after_nth_block=-1, channels_cfg=None, downsample_cfg=None, from_rgb_cfg=dict(type='SNGANDiscHeadResBlock'), blocks_cfg=dict(type='SNGANDiscResBlock'), act_cfg=dict(type='ReLU'), with_spectral_norm=True, sn_style='torch', sn_eps=1e-12, init_cfg=dict(type='BigGAN'), pretrained=None)[source]

Bases: torch.nn.Module

Discriminator for SNGAN / Proj-GAN. The implementation refers to https://github.com/pfnet-research/sngan_projection/tree/master/dis_models

The overall structure of the projection discriminator can be split into a from_rgb layer, a group of ResBlocks, a linear decision layer, and a projection layer. To support defining custom layers, we introduce from_rgb_cfg and blocks_cfg.
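As a hedged illustration of how from_rgb_cfg and blocks_cfg plug custom layers into the model, a config dict for this class might look like the following. The keys mirror the constructor arguments in the signature above; the registry name type='ProjDiscriminator' and the concrete values (a 32x32, 10-class setup) are assumptions chosen for illustration, not a recommended configuration.

```python
# Hypothetical config for a class-conditional 32x32 projection discriminator.
# Keys follow the ProjDiscriminator constructor; values are illustrative only.
disc_cfg = dict(
    type='ProjDiscriminator',  # assumed registry name
    input_scale=32,
    num_classes=10,            # > 0 enables label projection
    base_channels=128,
    from_rgb_cfg=dict(type='SNGANDiscHeadResBlock'),  # first (head) block
    blocks_cfg=dict(type='SNGANDiscResBlock'),        # intermediate blocks
    with_spectral_norm=True,
)

print(disc_cfg['from_rgb_cfg']['type'])
```

Replacing the type in from_rgb_cfg or blocks_cfg with another registered block is the intended way to swap in a custom layer.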

The model structure is closely tied to the image resolution. Therefore, we provide channels_cfg and downsample_cfg to control the input channels and the downsampling behavior of the intermediate blocks.

downsample_cfg: In the default config of SNGAN / Proj-GAN, whether to downsample in each intermediate block is flexible and depends on the image resolution. Therefore, users can define downsample_cfg themselves to control the structure of the discriminator.
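The effect of a downsample_cfg list on feature-map size can be traced with a short sketch. It assumes, hypothetically, that every downsampling block halves the spatial size and non-downsampling blocks keep it unchanged; it starts from the size entering the first intermediate block and ignores any downsampling done by the from_rgb head block. trace_resolution is an illustrative helper, not part of the library, and the example list is not claimed to be a default.

```python
def trace_resolution(start_scale, downsample_cfg, factor=2):
    """Return the spatial size after each intermediate block.

    Assumption: a block with downsample_cfg[idx] == True divides the size
    by `factor`; other blocks leave it unchanged.
    """
    sizes = []
    size = start_scale
    for apply_downsample in downsample_cfg:
        if apply_downsample:
            size //= factor
        sizes.append(size)
    return sizes


print(trace_resolution(32, [True, True, False, False]))  # [16, 8, 8, 8]
```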

channels_cfg: In the default config of SNGAN / Proj-GAN, the number of ResBlocks and their channels depend on the image resolution. Therefore, users can define channels_cfg to try their own models. We also provide a default config so that the model can be built from the image resolution alone.
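The channel rule for channels_cfg (block i goes from channels_cfg[i] to channels_cfg[i+1], both as multiples of base_channels) can be sketched as follows, including the dict form that selects a list by image scale. block_channels is a hypothetical helper for illustration; the example multipliers are made up.

```python
def block_channels(channels_cfg, base_channels, input_scale=None):
    """List (in_channels, out_channels) for each intermediate block.

    `channels_cfg` is either a list of multipliers of `base_channels`,
    or a dict mapping an image scale to such a list (hypothetical helper).
    """
    if isinstance(channels_cfg, dict):
        if input_scale not in channels_cfg:
            raise KeyError(f'no channels_cfg entry for scale {input_scale}')
        channels_cfg = channels_cfg[input_scale]
    factors = list(channels_cfg)
    # Block i maps channels_cfg[i] * base to channels_cfg[i + 1] * base.
    return [(factors[i] * base_channels, factors[i + 1] * base_channels)
            for i in range(len(factors) - 1)]


print(block_channels([1, 2, 4, 8], 64))
# [(64, 128), (128, 256), (256, 512)]
print(block_channels({32: [1, 1, 1]}, 128, input_scale=32))
# [(128, 128), (128, 128)]
```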

Parameters
  • input_scale (int) – The scale of the input image.

  • num_classes (int, optional) – The number of classes. If num_classes=0, no label projection is used. Defaults to 0.

  • base_channels (int, optional) – The base channel number of the discriminator. The channel numbers of the other layers are derived from it. Defaults to 128.

  • input_channels (int, optional) – Channels of the input image. Defaults to 3.

  • attention_cfg (dict, optional) – Config for the self-attention block. Defaults to dict(type='SelfAttentionBlock').

  • attention_after_nth_block (int | list[int], optional) – Index of the ConvBlock (counting the head block) after which a self-attention block is added. If an int is passed, only one attention block is added. If a list is passed, self-attention blocks are added after multiple ConvBlocks. Note that indices smaller than 1 are ignored, so the default adds no self-attention block. Defaults to -1.

  • channels_cfg (list | dict[list], optional) – Config for the input channels of the intermediate blocks. If a list is passed, each element is the multiplier of base_channels for the input channels of the corresponding block: block i has channels_cfg[i] * base_channels input channels and channels_cfg[i+1] * base_channels output channels. If a dict is passed, each key should be an input scale and the corresponding value a list defining the channels. Default: please refer to _defualt_channels_cfg.

  • downsample_cfg (list[bool] | dict[list], optional) – Config for the downsampling behavior of the intermediate blocks. If a list is passed, downsample_cfg[idx] == True means downsampling is applied in the idx-th block, and vice versa. If a dict is passed, each key should be the input scale of the image and the corresponding value a list defining the downsampling behavior. Default: please refer to _defualt_downsample_cfg.

  • from_rgb_cfg (dict, optional) – Config for the first layer, which converts the RGB image to a feature map. Defaults to dict(type='SNGANDiscHeadResBlock').

  • blocks_cfg (dict, optional) – Config for the intermediate blocks. Defaults to dict(type='SNGANDiscResBlock').

  • act_cfg (dict, optional) – Activation config for the final output layer. Defaults to dict(type='ReLU').

  • with_spectral_norm (bool, optional) – Whether to use spectral norm for all conv blocks. Defaults to True.

  • sn_style (str, optional) – The style of spectral normalization. If set to ajbrock, the implementation by ajbrock (https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py) is adopted. If set to torch, the PyTorch implementation is adopted. Defaults to torch.

  • sn_eps (float, optional) – Epsilon for the spectral normalization operation. Defaults to 1e-12.

  • init_cfg (dict, optional) – Config for weight initialization. Defaults to dict(type='BigGAN').

  • pretrained (str | dict, optional) – Path to the pretrained model, or a dict containing information about the pretrained model whose required key is 'ckpt_path'. You can also provide 'prefix' to load the discriminator part from the whole state dict. Defaults to None.
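The rule for attention_after_nth_block can be sketched as a small normalization step. normalize_attention_indices is a hypothetical helper, not part of the library; it only encodes the two documented behaviors (an int selects a single block, and indices smaller than 1 are ignored), which is why the default of -1 in the signature adds no attention block.

```python
def normalize_attention_indices(attention_after_nth_block):
    """Return the block indices after which self-attention is inserted.

    Hypothetical helper mirroring the documented rules only.
    """
    # An int means a single attention block.
    if isinstance(attention_after_nth_block, int):
        attention_after_nth_block = [attention_after_nth_block]
    # Indices smaller than 1 are ignored, as documented.
    return [idx for idx in attention_after_nth_block if idx >= 1]


print(normalize_attention_indices(-1))         # []
print(normalize_attention_indices([0, 2, 4]))  # [2, 4]
```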

_defualt_channels_cfg[source]
_defualt_downsample_cfg[source]
forward(x, label=None)[source]

Forward function. If self.num_classes is larger than 0, label projection is used.

Parameters
  • x (torch.Tensor) – Fake or real image tensor.

  • label (torch.Tensor, optional) – Labels corresponding to the input images. Note that if self.num_classes is larger than 0, label must not be None. Defaults to None.

Returns

Prediction for the reality of the input image.

Return type

torch.Tensor
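The label projection used by forward can be sketched for a single sample. The sketch assumes the image has already been reduced to one pooled feature vector h by the ResBlocks (how this class pools features is not covered here). It then adds a linear decision term to the inner product between the class embedding of the label and h, which is the projection formulation of the SNGAN-projection reference implementation linked above. Plain Python lists stand in for tensors, and projection_logit and its arguments are hypothetical names.

```python
def projection_logit(h, w, b, embedding=None, label=None):
    """Score one sample: linear decision term plus optional label projection.

    h: pooled feature vector (assumed already computed by the ResBlocks).
    w, b: weights and bias of the linear decision layer.
    embedding: class-embedding table, one vector per class (num_classes > 0).
    """
    # Linear decision layer: <w, h> + b.
    logit = sum(wi * hi for wi, hi in zip(w, h)) + b
    if embedding is not None:
        if label is None:
            raise ValueError('label is required when label projection is used')
        # Projection term: inner product of the label's embedding and h.
        logit += sum(ei * hi for ei, hi in zip(embedding[label], h))
    return logit


h = [1.0, 2.0]
print(projection_logit(h, w=[0.5, -0.25], b=0.1))  # 0.1
print(projection_logit(h, w=[0.5, -0.25], b=0.1,
                       embedding=[[1.0, 0.0], [0.0, 1.0]], label=1))  # 2.1
```

With num_classes=0 there is no embedding table, so only the linear term remains, which matches the unconditional case described for num_classes.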

init_weights(pretrained=None, strict=True)[source]

Initialize weights for SNGAN-Proj and SAGAN. If pretrained=None, weight initialization follows the INIT_TYPE given in init_cfg=dict(type=INIT_TYPE).

For SNGAN-Proj (INIT_TYPE.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']), we follow the initialization method of the official Chainer implementation (https://github.com/pfnet-research/sngan_projection).

For SAGAN (INIT_TYPE.upper() == 'SAGAN'), we follow the initialization method of the official TensorFlow implementation (https://github.com/brain-research/self-attention-gan).

Besides reimplementing the official initializations, we provide BigGAN-style and PyTorch-StudioGAN-style initialization (INIT_TYPE.upper() == 'BIGGAN' and INIT_TYPE.upper() == 'STUDIO', respectively). Please refer to https://github.com/ajbrock/BigGAN-PyTorch and https://github.com/POSTECH-CVLab/PyTorch-StudioGAN.
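The dispatch on INIT_TYPE described above can be summarized as a small lookup. init_family is a hypothetical helper that returns only a descriptive label for each documented case; raising an error for any other value is this sketch's own choice, not documented library behavior.

```python
def init_family(init_type):
    """Map an init_cfg type string to the initialization style it selects.

    Hypothetical helper; the returned labels are descriptive only.
    """
    key = init_type.upper()  # matching is case-insensitive, as documented
    if key in ('SNGAN', 'SNGAN-PROJ', 'GAN-PROJ'):
        return 'official Chainer SNGAN-Proj'
    if key == 'SAGAN':
        return 'official TensorFlow SAGAN'
    if key == 'BIGGAN':
        return 'BigGAN-PyTorch'
    if key == 'STUDIO':
        return 'PyTorch-StudioGAN'
    # Assumption: unknown types are rejected in this sketch.
    raise ValueError(f'unknown init type: {init_type}')


print(init_family('BigGAN'))  # the default init_cfg selects 'BigGAN-PyTorch'
```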

Parameters

  • pretrained (str | dict, optional) – Path to the pretrained model, or a dict containing information about the pretrained model whose required key is 'ckpt_path'. You can also provide 'prefix' to load the discriminator part from the whole state dict. Defaults to None.
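The role of 'prefix' can be illustrated with a plain dict. This is a sketch under the assumption that the prefix selects the keys of one sub-module (stored as '<prefix>.<param name>') and is stripped before loading; the real loader may differ. select_prefixed, the checkpoint path, and the key names are all hypothetical.

```python
def select_prefixed(state_dict, prefix):
    """Keep entries whose keys start with `prefix` and strip the prefix.

    Hypothetical helper: assumes sub-module keys look like
    '<prefix>.<param name>' in the full state dict.
    """
    if not prefix.endswith('.'):
        prefix += '.'
    return {key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)}


# Hypothetical `pretrained` argument; the path is a placeholder.
pretrained = dict(ckpt_path='path/to/full_gan_checkpoint.pth',
                  prefix='discriminator')

full_state = {
    'generator.fc.weight': 1,
    'discriminator.fc.weight': 2,
    'discriminator.fc.bias': 3,
}
print(select_prefixed(full_state, pretrained['prefix']))
# {'fc.weight': 2, 'fc.bias': 3}
```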
