Shortcuts

Source code for mmedit.models.editors.liif.mlp_refiner

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmengine.model import BaseModule

from mmedit.registry import MODELS


@MODELS.register_module()
[docs]class MLPRefiner(BaseModule): """Multilayer perceptrons (MLPs), refiner used in LIIF. Args: in_dim (int): Input dimension. out_dim (int): Output dimension. hidden_list (list[int]): List of hidden dimensions. """ def __init__(self, in_dim, out_dim, hidden_list): super().__init__() layers = [] last_channels = in_dim for hidden in hidden_list: layers.append(nn.Linear(last_channels, hidden)) layers.append(nn.ReLU()) last_channels = hidden layers.append(nn.Linear(last_channels, out_dim)) self.layers = nn.Sequential(*layers)
[docs] def forward(self, x): """Forward function. Args: x (Tensor): The input of MLP. Returns: Tensor: The output of MLP. """ shape = x.shape[:-1] x = self.layers(x.view(-1, x.shape[-1])) return x.view(*shape, -1)
Read the Docs v: latest
Versions
master
latest
stable
zyh-re-docs
zyh-doc-notfound-extend
zyh-api-rendering
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.