# Source code for mmedit.models.base_archs.multi_layer_disc

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine import MMLogger
from mmengine.runner import load_checkpoint
from torch import Tensor

from mmedit.models.base_archs import LinearModule
from mmedit.registry import MODELS

class MultiLayerDiscriminator(nn.Module):
    """Multilayer Discriminator.

    A commonly used structure made of multiple stacked stride-2 convolution
    layers, optionally followed by two stride-1 output convs and/or a fully
    connected layer.

    Args:
        in_channels (int): Input channel of the first input convolution.
        max_channels (int): The maximum channel number in this structure.
        num_convs (int): Number of stacked intermediate convs (including
            input conv but excluding output conv). Default to 5.
        fc_in_channels (int | None): Input dimension of the fully connected
            layer. If `fc_in_channels` is None, the fully connected layer
            will be removed. Default to None.
        fc_out_channels (int): Output dimension of the fully connected layer.
            Default to 1024.
        kernel_size (int): Kernel size of the conv modules. Default to 5.
        conv_cfg (dict): Config dict to build conv layer.
            NOTE(review): this argument is accepted but is not forwarded to
            the conv modules built here; confirm whether that is intended.
        norm_cfg (dict): Config dict to build norm layer.
        act_cfg (dict): Config dict for activation layer, "relu" by default.
        out_act_cfg (dict): Config dict for output activation, "relu" by
            default.
        with_input_norm (bool): Whether add normalization after the input
            conv. Default to True.
        with_out_convs (bool): Whether add output convs to the discriminator.
            The output convs contain two convs. The first out conv has the
            same setting as the intermediate convs but a stride of 1 instead
            of 2. The second out conv is a conv similar to the first out conv
            but reduces the number of channels to 1 and has no activation
            layer. Default to False.
        with_spectral_norm (bool): Whether use spectral norm after the conv
            layers. Default to False.
        kwargs (keyword arguments): Extra keyword arguments forwarded to
            every ``ConvModule`` built by this class.
    """

    def __init__(self,
                 in_channels: int,
                 max_channels: int,
                 num_convs: int = 5,
                 fc_in_channels: Optional[int] = None,
                 fc_out_channels: int = 1024,
                 kernel_size: int = 5,
                 conv_cfg: Optional[dict] = None,
                 norm_cfg: Optional[dict] = None,
                 act_cfg: Optional[dict] = dict(type='ReLU'),
                 out_act_cfg: Optional[dict] = dict(type='ReLU'),
                 with_input_norm: bool = True,
                 with_out_convs: bool = False,
                 with_spectral_norm: bool = False,
                 **kwargs):
        super().__init__()
        if fc_in_channels is not None:
            assert fc_in_channels > 0

        self.max_channels = max_channels
        self.with_fc = fc_in_channels is not None
        self.num_convs = num_convs
        self.with_out_act = out_act_cfg is not None
        self.with_out_convs = with_out_convs

        # Stride-2 conv stack; channel width doubles per layer from 64 and is
        # capped at ``max_channels``.
        cur_channels = in_channels
        for i in range(num_convs):
            out_ch = min(64 * 2**i, max_channels)
            norm_cfg_ = norm_cfg
            act_cfg_ = act_cfg
            if i == 0 and not with_input_norm:
                # The input conv may opt out of normalization.
                norm_cfg_ = None
            elif (i == num_convs - 1 and not self.with_fc
                  and not self.with_out_convs):
                # This conv is the final layer of the whole network, so it
                # drops the norm and uses the output activation instead.
                norm_cfg_ = None
                act_cfg_ = out_act_cfg
            self.add_module(
                f'conv{i + 1}',
                ConvModule(
                    cur_channels,
                    out_ch,
                    kernel_size=kernel_size,
                    stride=2,
                    padding=kernel_size // 2,
                    norm_cfg=norm_cfg_,
                    act_cfg=act_cfg_,
                    with_spectral_norm=with_spectral_norm,
                    **kwargs))
            cur_channels = out_ch

        if self.with_out_convs:
            # Two extra stride-1 convs; the last one maps to a single channel
            # and has no activation.
            cur_channels = min(64 * 2**(num_convs - 1), max_channels)
            out_ch = min(64 * 2**num_convs, max_channels)
            self.add_module(
                f'conv{num_convs + 1}',
                ConvModule(
                    cur_channels,
                    out_ch,
                    kernel_size,
                    stride=1,
                    padding=kernel_size // 2,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    with_spectral_norm=with_spectral_norm,
                    **kwargs))
            self.add_module(
                f'conv{num_convs + 2}',
                ConvModule(
                    out_ch,
                    1,
                    kernel_size,
                    stride=1,
                    padding=kernel_size // 2,
                    act_cfg=None,
                    with_spectral_norm=with_spectral_norm,
                    **kwargs))

        if self.with_fc:
            self.fc = LinearModule(
                fc_in_channels,
                fc_out_channels,
                bias=True,
                act_cfg=out_act_cfg,
                with_spectral_norm=with_spectral_norm)

    def forward(self, x: Tensor) -> Tensor:
        """Forward Function.

        Args:
            x (torch.Tensor): Input tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Output tensor with shape of (n, c, h', w') or (n, c).
        """
        batch_size = x.size(0)

        # out_convs has two additional ConvModules
        total_convs = self.num_convs + (2 if self.with_out_convs else 0)
        for idx in range(1, total_convs + 1):
            x = getattr(self, f'conv{idx}')(x)

        if self.with_fc:
            # Flatten everything but the batch dimension for the fc layer.
            x = x.view(batch_size, -1)
            x = self.fc(x)

        return x

    def init_weights(self, pretrained: Optional[str] = None) -> None:
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = MMLogger.get_current_instance()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                # Here, we only initialize the module with fc layer since the
                # conv and norm layers has been initialized in `ConvModule`.
                if isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0.0, 0.02)
                    # A Linear layer built with ``bias=False`` has no bias.
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0.0)
        else:
            raise TypeError('pretrained must be a str or None')
# Read the Docs v: latest
# On Read the Docs
# Project Home
#
# Free document hosting provided by Read the Docs.