Shortcuts

Source code for mmedit.models.base_archs.vgg

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional

import torch.nn as nn
from mmengine.model import BaseModule
from mmengine.model.weight_init import constant_init, xavier_init
from torch import Tensor

from mmedit.registry import MODELS
from ..base_archs.aspp import ASPP


@MODELS.register_module()
class VGG16(BaseModule):
    """Customized VGG16 Encoder.

    A 1x1 conv is added after the original VGG16 conv layers. The indices of
    max pooling layers are returned for unpooling layers in decoders.

    Args:
        in_channels (int): Number of input channels.
        batch_norm (bool, optional): Whether use ``nn.BatchNorm2d``.
            Default to False.
        aspp (bool, optional): Whether use ASPP module after the last conv
            layer. Default to False.
        dilations (list[int], optional): Atrous rates of ASPP module. When
            None, no ``dilations`` argument is forwarded to ASPP.
            Default to None.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels: int,
                 batch_norm: Optional[bool] = False,
                 aspp: Optional[bool] = False,
                 dilations: Optional[List[int]] = None,
                 init_cfg: Optional[dict] = None):
        super().__init__(init_cfg=init_cfg)
        self.batch_norm = batch_norm
        # ``self.aspp`` starts as the boolean flag and is replaced by the ASPP
        # module below when the flag is set. The attribute name is kept
        # because it is also the sub-module name.
        self.aspp = aspp
        self.dilations = dilations

        # Five conv stages, each ending in a max pool that returns indices.
        self.layer1 = self._make_layer(in_channels, 64, 2)
        self.layer2 = self._make_layer(64, 128, 2)
        self.layer3 = self._make_layer(128, 256, 3)
        self.layer4 = self._make_layer(256, 512, 3)
        self.layer5 = self._make_layer(512, 512, 3)

        # Extra 1x1 conv appended after the VGG16 conv stages.
        self.conv6 = nn.Conv2d(512, 512, kernel_size=1)
        if self.batch_norm:
            self.bn = nn.BatchNorm2d(512)
        self.relu = nn.ReLU(inplace=True)

        if self.aspp:
            # Forward ``dilations`` only when the caller supplied one, so that
            # the documented default (None) does not override ASPP's own
            # default with None.
            # NOTE(review): assumes ASPP declares a default for ``dilations``
            # -- confirm against ASPP's signature.
            aspp_kwargs = {}
            if self.dilations is not None:
                aspp_kwargs['dilations'] = self.dilations
            self.aspp = ASPP(512, **aspp_kwargs)
            self.out_channels = 256
        else:
            self.out_channels = 512

    def _make_layer(self, inplanes: int, planes: int,
                    convs_layers: int) -> nn.Module:
        """Build one encoder stage.

        The stage is ``convs_layers`` repetitions of a 3x3 conv (padding 1),
        an optional ``nn.BatchNorm2d`` (when ``self.batch_norm`` is set) and
        an in-place ReLU, followed by a 2x2 stride-2 max pool created with
        ``return_indices=True``.

        Args:
            inplanes (int): Number of input channels of the first conv.
            planes (int): Number of output channels of every conv.
            convs_layers (int): Number of conv layers in the stage.

        Returns:
            nn.Module: ``nn.Sequential`` holding the stage. Because the last
            layer is the index-returning max pool, ``forward`` unpacks its
            result as ``(output, indices)``.
        """
        layers = []
        for _ in range(convs_layers):
            conv2d = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1)
            if self.batch_norm:
                bn = nn.BatchNorm2d(planes)
                layers += [conv2d, bn, nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            # Every conv after the first one maps planes -> planes.
            inplanes = planes
        layers += [nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)]
        return nn.Sequential(*layers)

    def init_weights(self) -> None:
        """Init weights for the model.

        When ``init_cfg`` is given, initialization is delegated to the parent
        class. Otherwise every ``nn.Conv2d`` is passed to ``xavier_init`` and
        every ``nn.BatchNorm2d`` to ``constant_init`` with value 1.
        """
        if self.init_cfg is not None:
            super().init_weights()
        else:
            # Default initialization
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    xavier_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Forward function of the VGG16 encoder.

        Args:
            x (Tensor): Input tensor with shape (N, C, H, W).

        Returns:
            dict: Dict containing output tensor (key ``'out'``) and
            maxpooling indices (keys ``'max_idx_1'`` to ``'max_idx_5'``).
        """
        # Each stage yields (features, pooling indices).
        out, max_idx_1 = self.layer1(x)
        out, max_idx_2 = self.layer2(out)
        out, max_idx_3 = self.layer3(out)
        out, max_idx_4 = self.layer4(out)
        out, max_idx_5 = self.layer5(out)

        out = self.conv6(out)
        if self.batch_norm:
            out = self.bn(out)
        out = self.relu(out)

        # ``self.aspp`` is either the falsy flag or the ASPP module built in
        # ``__init__``.
        if self.aspp:
            out = self.aspp(out)

        return {
            'out': out,
            'max_idx_1': max_idx_1,
            'max_idx_2': max_idx_2,
            'max_idx_3': max_idx_3,
            'max_idx_4': max_idx_4,
            'max_idx_5': max_idx_5
        }
Read the Docs v: latest
Versions
master
latest
stable
zyh-re-docs
zyh-doc-notfound-extend
zyh-api-rendering
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.