
mmedit.models.utils.model_utils

Module Contents

Functions

default_init_weights(module[, scale])

Initialize network weights.

make_layer(block, num_blocks, **kwarg)

Make layers by stacking the same blocks.

get_module_device(module)

Get the device of a module.

set_requires_grad(nets[, requires_grad])

Set requires_grad for all the networks.

generation_init_weights(module[, init_type, init_gain])

Default initialization of network weights for image generation.

get_valid_noise_size(noise_size, generator) → Optional[int]

Get the value of noise_size from the input and the generator, and check the consistency of these values.

get_valid_num_batches(batch_inputs) → int

Try to get the valid batch size from inputs.

mmedit.models.utils.model_utils.default_init_weights(module, scale=1)

Initialize network weights.

Parameters
  • module (nn.Module) – Module to be initialized.

  • scale (float) – Scale initialized weights, especially for residual blocks. Default: 1.
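The following is a minimal sketch of what this kind of initialization typically does, not the library's implementation. It assumes Kaiming-normal initialization (fan_in mode) for Conv2d and Linear weights, zero biases, and BatchNorm weights set to 1; the name default_init_weights_sketch is hypothetical. What it illustrates is the role of scale: the freshly initialized weights are multiplied by it, so a small value such as 0.1 keeps a residual branch's output small at the start of training, and the block behaves roughly like an identity mapping.

```python
import torch
import torch.nn as nn

_BATCH_NORMS = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


@torch.no_grad()
def default_init_weights_sketch(module: nn.Module, scale: float = 1.0) -> None:
    """Hypothetical stand-in for default_init_weights (assumed init scheme)."""
    for m in module.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            # Assumption: Kaiming-normal init in fan_in mode.
            nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            # Scale the initialized weights (e.g. 0.1 for residual branches).
            m.weight.mul_(scale)
            if m.bias is not None:
                m.bias.zero_()
        elif isinstance(m, _BATCH_NORMS) and m.weight is not None:
            # Affine BatchNorm starts as an identity transform.
            m.weight.fill_(1.0)
            m.bias.zero_()


# Usage: shrink the initial weights of a residual branch.
branch = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.BatchNorm2d(16))
default_init_weights_sketch(branch, scale=0.1)
```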

mmedit.models.utils.model_utils.make_layer(block, num_blocks, **kwarg)

Make layers by stacking the same blocks.

Parameters
  • block (nn.Module) – nn.Module class for the basic block.

  • num_blocks (int) – Number of blocks.

  • **kwarg – Keyword arguments passed to the constructor of each block.

Returns

Stacked blocks in nn.Sequential.

Return type

nn.Sequential
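Based on the description above, make_layer can be pictured as instantiating block(**kwarg) num_blocks times and wrapping the instances in nn.Sequential. The sketch below assumes exactly that; make_layer_sketch and the ResidualBlock class are hypothetical names used only for illustration. Because each block is constructed separately, the stacked blocks do not share parameters.

```python
import torch
import torch.nn as nn


def make_layer_sketch(block: type, num_blocks: int, **kwarg) -> nn.Sequential:
    """Hypothetical stand-in: build num_blocks fresh instances of block."""
    layers = [block(**kwarg) for _ in range(num_blocks)]  # same kwargs for every block
    return nn.Sequential(*layers)


class ResidualBlock(nn.Module):
    """Hypothetical basic block, used only for this example."""

    def __init__(self, channels: int = 64):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.conv(x)


body = make_layer_sketch(ResidualBlock, 4, channels=16)
out = body(torch.randn(1, 16, 8, 8))
print(len(body), tuple(out.shape))  # 4 (1, 16, 8, 8)
```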

mmedit.models.utils.model_utils.get_module_device(module)

Get the device of a module.

Parameters

module (nn.Module) – A module that contains parameters.

Returns

The device of the module.

Return type

torch.device
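A common way to find a module's device is to look at its first parameter, which is why the module must contain parameters. The sketch below assumes that approach and assumes all parameters sit on the same device; get_module_device_sketch is a hypothetical name, and raising ValueError for a parameter-free module is a behavior chosen for this sketch.

```python
import torch
import torch.nn as nn


def get_module_device_sketch(module: nn.Module) -> torch.device:
    """Hypothetical stand-in: report the device of the first parameter."""
    try:
        first_param = next(module.parameters())
    except StopIteration:
        # A module without parameters (e.g. nn.ReLU) has no device to report.
        raise ValueError('The input module should contain parameters.') from None
    return first_param.device


print(get_module_device_sketch(nn.Linear(2, 2)))  # cpu
```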

mmedit.models.utils.model_utils.set_requires_grad(nets, requires_grad=False)

Set requires_grad for all the networks.

Parameters
  • nets (nn.Module | list[nn.Module]) – A list of networks or a single network.

  • requires_grad (bool) – Whether the networks require gradients. Default: False.
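Setting requires_grad to False is typically used to freeze one network while another is updated, for example keeping a discriminator fixed during a generator step in GAN training. The sketch below assumes the function simply loops over all parameters of each network; set_requires_grad_sketch and the small nn.Linear "generator" and "discriminator" are hypothetical.

```python
from typing import List, Union

import torch.nn as nn


def set_requires_grad_sketch(nets: Union[nn.Module, List[nn.Module]],
                             requires_grad: bool = False) -> None:
    """Hypothetical stand-in: toggle requires_grad on every parameter."""
    if not isinstance(nets, list):
        nets = [nets]  # accept a single network as well as a list
    for net in nets:
        for param in net.parameters():
            param.requires_grad = requires_grad


generator, discriminator = nn.Linear(8, 4), nn.Linear(4, 1)
set_requires_grad_sketch(discriminator)  # freeze the discriminator (default False)
set_requires_grad_sketch([generator, discriminator], True)  # unfreeze both
```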

mmedit.models.utils.model_utils.generation_init_weights(module, init_type='normal', init_gain=0.02)

Default initialization of network weights for image generation.

By default, we use normal init, but xavier and kaiming might work better for some applications.

Parameters
  • module (nn.Module) – Module to be initialized.

  • init_type (str) – The name of an initialization method: normal | xavier | kaiming | orthogonal. Default: ‘normal’.

  • init_gain (float) – Scaling factor for normal, xavier and orthogonal. Default: 0.02.
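The sketch below shows one way to map the four init_type values onto torch.nn.init functions. It is an assumption modeled on the common pix2pix/CycleGAN-style initializer, not the library's code: it touches only Conv2d, ConvTranspose2d, Linear and BatchNorm2d layers, zeroes biases, and, matching the parameter description, ignores init_gain for kaiming. generation_init_weights_sketch is a hypothetical name.

```python
import torch.nn as nn


def generation_init_weights_sketch(module: nn.Module,
                                   init_type: str = 'normal',
                                   init_gain: float = 0.02) -> None:
    """Hypothetical stand-in for generation_init_weights (assumed details)."""

    def init_func(m: nn.Module) -> None:
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            if init_type == 'normal':
                nn.init.normal_(m.weight, 0.0, init_gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=init_gain)
            elif init_type == 'kaiming':
                # init_gain is not used here, as in the parameter description.
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight, gain=init_gain)
            else:
                raise NotImplementedError(
                    f"Initialization method '{init_type}' is not implemented")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            # Assumption: BatchNorm weights are drawn around 1.0.
            nn.init.normal_(m.weight, 1.0, init_gain)
            nn.init.zeros_(m.bias)

    module.apply(init_func)  # visit every submodule recursively


net = nn.Sequential(nn.Conv2d(3, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2))
generation_init_weights_sketch(net, init_type='xavier')
```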

mmedit.models.utils.model_utils.get_valid_noise_size(noise_size: Optional[int], generator: Union[Dict, torch.nn.Module]) → Optional[int]

Get the value of noise_size from the input and the generator, and check the consistency of these values. If no conflict is found, return that value.

Parameters
  • noise_size (Optional[int]) – noise_size passed to BaseGAN_refactor’s initialize function.

  • generator (ModelType) – The config or the model of the generator.

Returns

The noise size fed to the generator.

Return type

int | None
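The consistency check described above can be sketched as follows, assuming the generator carries its value as a noise_size key (for a config dict) or a noise_size attribute (for a built module). Both that convention and the name get_valid_noise_size_sketch are assumptions made for illustration.

```python
from typing import Dict, Optional, Union

import torch.nn as nn


def get_valid_noise_size_sketch(noise_size: Optional[int],
                                generator: Union[Dict, nn.Module]) -> Optional[int]:
    """Hypothetical stand-in: reconcile the input and generator noise sizes."""
    # Assumption: a config dict stores the value under the 'noise_size' key,
    # and a built generator exposes it as a 'noise_size' attribute.
    if isinstance(generator, dict):
        model_noise_size = generator.get('noise_size')
    else:
        model_noise_size = getattr(generator, 'noise_size', None)

    if noise_size is not None and model_noise_size is not None:
        # Both sources are set: they must agree.
        assert noise_size == model_noise_size, (
            f'Input noise_size ({noise_size}) conflicts with '
            f'generator noise_size ({model_noise_size}).')
        return noise_size
    # At most one source is set: return it (or None if neither is).
    return noise_size if noise_size is not None else model_noise_size


print(get_valid_noise_size_sketch(None, dict(noise_size=128)))  # 128
```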

mmedit.models.utils.model_utils.get_valid_num_batches(batch_inputs: mmedit.utils.typing.ForwardInputs) → int

Try to get the valid batch size from inputs.

  • If some values in batch_inputs are Tensors and ‘num_batches’ is in batch_inputs, we check whether the value of ‘num_batches’ and the length of the first dimension of every tensor are the same. If they differ, an AssertionError is raised; otherwise, that value is returned.

  • If no value in batch_inputs is a Tensor, ‘num_batches’ must be contained in batch_inputs, and its value is returned.

  • If some values are Tensors and ‘num_batches’ is not contained in batch_inputs, we check whether all tensors have the same length on the first dimension. If the lengths differ, an AssertionError is raised; otherwise, that length is returned as the batch size.

  • If batch_inputs is a Tensor, directly return the length of the first dimension as batch size.

Parameters

batch_inputs (ForwardInputs) – Inputs passed to forward().

Returns

The batch size of samples to generate.

Return type

int
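The four rules above reduce to collecting every candidate batch size and requiring them to agree. The sketch below assumes ForwardInputs is either a Tensor or a dict; get_valid_num_batches_sketch and the dict keys noise and img are hypothetical.

```python
from typing import Any, Dict, Union

import torch
from torch import Tensor


def get_valid_num_batches_sketch(batch_inputs: Union[Tensor, Dict[str, Any]]) -> int:
    """Hypothetical stand-in: infer one consistent batch size from inputs."""
    # A bare Tensor: its first dimension is the batch size.
    if isinstance(batch_inputs, Tensor):
        return batch_inputs.shape[0]

    # Collect the first-dimension length of every Tensor value ...
    candidates = {k: v.shape[0] for k, v in batch_inputs.items() if isinstance(v, Tensor)}
    # ... plus the explicit 'num_batches' entry, if present.
    if 'num_batches' in batch_inputs:
        candidates['num_batches'] = batch_inputs['num_batches']

    # Without Tensor values, 'num_batches' is the only possible source.
    assert candidates, "Cannot infer the batch size: no Tensor values and no 'num_batches'."
    sizes = set(candidates.values())
    assert len(sizes) == 1, f'Inconsistent batch sizes: {candidates}'
    return sizes.pop()


inputs = {'noise': torch.randn(2, 128), 'num_batches': 2}
print(get_valid_num_batches_sketch(inputs))  # 2
```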
