mmedit.engine.hooks.ema

Module Contents

Classes

ExponentialMovingAverageHook

Exponential Moving Average Hook.

Attributes

DATA_BATCH

mmedit.engine.hooks.ema.DATA_BATCH
class mmedit.engine.hooks.ema.ExponentialMovingAverageHook(module_keys, interp_mode='lerp', interp_cfg=None, interval=-1, start_iter=0)

Bases: mmengine.hooks.Hook

Exponential Moving Average Hook.

Exponential moving average (EMA) is a trick that is widely used in the GAN literature, e.g., PGGAN, StyleGAN, and BigGAN. The general idea is to maintain a second model with the same architecture whose parameters are updated as a moving average of the trained weights of the original model. In practice, the model with moving-averaged weights often achieves better performance than the original model.
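To make the moving-average idea concrete, here is a minimal sketch, not mmedit's implementation. Plain floats stand in for tensors, a dict stands in for a model's state_dict, and the helper name ema_step is hypothetical. Each call moves every EMA value a fraction momentum of the way toward the corresponding original value, so the EMA copy trails the trained weights smoothly.

```python
# Minimal EMA sketch: floats stand in for tensors and a dict stands in for a
# model's state_dict. `ema_step` is a hypothetical helper, not an mmedit API.


def ema_step(ema_state: dict, orig_state: dict, momentum: float = 0.001) -> dict:
    """Move every EMA value a fraction `momentum` toward the original value."""
    return {
        name: ema_value + momentum * (orig_state[name] - ema_value)
        for name, ema_value in ema_state.items()
    }


orig = {"weight": 1.0}
ema = {"weight": 0.0}
for _ in range(3):
    # A large momentum is used here only to make the effect visible.
    ema = ema_step(ema, orig, momentum=0.5)
print(ema)  # {'weight': 0.875}
```

With a realistic momentum such as 0.001, the EMA copy changes very slowly, which is what smooths out the noise of individual optimizer steps.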

Parameters
  • module_keys (str | tuple[str]) – The name(s) of the EMA model(s). Each key must end with ‘_ema’ so that the original model can be found by discarding the last four characters (e.g., ‘generator_ema’ maps to ‘generator’).

  • interp_mode (str, optional) – Mode of the interpolation method. Defaults to ‘lerp’.

  • interp_cfg (dict | None, optional) – Arguments passed to the interpolation function. Defaults to None.

  • interval (int, optional) – Interval (in iterations) between EMA updates. Default: -1.

  • start_iter (int, optional) – Start iteration for EMA. Before this iteration is reached, the weights of the EMA model are kept identical to those of the original model; afterwards, they are updated as a moving average of the original model's trained weights. Default: 0.
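As an illustration of how these parameters fit together, the following is a hypothetical mmengine-style config fragment registering the hook through custom_hooks. The values are illustrative, the model is assumed to define both generator and generator_ema attributes, and interp_cfg is assumed to be forwarded as keyword arguments to the lerp function documented below.

```python
# Hypothetical config fragment; values are illustrative, not recommended.
# Assumes the model defines both `generator` and `generator_ema`.
custom_hooks = [
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema',),   # original model: 'generator'
        interp_mode='lerp',
        interp_cfg=dict(momentum=0.001),  # assumed to be passed to lerp()
        interval=1,
        start_iter=0,
    )
]

# The original model's name is recovered by dropping the 4-character '_ema' suffix.
for key in custom_hooks[0]['module_keys']:
    assert key.endswith('_ema')
    print(key, '->', key[:-4])  # generator_ema -> generator
```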

static lerp(a, b, momentum=0.001, momentum_nontrainable=1.0, trainable=True)

Performs a linear interpolation of two parameters/buffers.

Parameters
  • a (torch.Tensor) – Interpolation start point, corresponding to the original model's state.

  • b (torch.Tensor) – Interpolation end point, corresponding to the EMA model's state.

  • momentum (float, optional) – The weight for the interpolation formula. Defaults to 0.001.

  • momentum_nontrainable (float, optional) – The weight for the interpolation formula used for non-trainable parameters. Defaults to 1.0.

  • trainable (bool, optional) – Whether the input parameter is trainable. If set to False, momentum_nontrainable is used instead of momentum. Defaults to True.

Returns

Interpolation result.

Return type

torch.Tensor
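The documentation above does not spell out the interpolation formula. The sketch below assumes result = b + (a - b) * m, where m is momentum for trainable inputs and momentum_nontrainable otherwise. This choice is inferred from the defaults rather than confirmed: a small momentum (0.001) then keeps the EMA state close to its previous value, and momentum_nontrainable=1.0 copies non-trainable buffers straight from the original model. The same expression works unchanged on torch.Tensor operands.

```python
def lerp(a: float, b: float, momentum: float = 0.001,
         momentum_nontrainable: float = 1.0, trainable: bool = True) -> float:
    """Sketch of an EMA-style linear interpolation (assumed formula).

    `a` is the original value and `b` is the current EMA value. The result
    moves `b` toward `a` by the weight `m`: b + (a - b) * m.
    """
    m = momentum if trainable else momentum_nontrainable
    return b + (a - b) * m


print(lerp(1.0, 0.0, momentum=0.25))    # 0.25: EMA moves a quarter of the way
print(lerp(3.0, 0.0, trainable=False))  # 3.0: weight 1.0 copies the original
```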

every_n_iters(runner: mmengine.runner.Runner, n: int)

This is the function to perform every n iterations.

Parameters
  • runner (Runner) – runner used to drive the whole pipeline

  • n (int) – the interval, in iterations

Returns

the latest iterations

Return type

int
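A hypothetical schedule check helps show how interval and start_iter could interact. The helper should_update is not part of mmedit, and several behaviours are assumptions: iterations are 0-indexed, updates fire every interval iterations counted from start_iter, the hook acts on every iteration before start_iter (so the EMA copy can be kept identical to the original), and a non-positive interval such as the default -1 means "every iteration". The documentation does not state what -1 means.

```python
def should_update(cur_iter: int, start_iter: int, interval: int) -> bool:
    """Hypothetical schedule check for an EMA hook (assumptions in lead-in).

    Returns True when the hook should act at the 0-indexed `cur_iter`.
    """
    if cur_iter < start_iter:
        # Before start_iter the EMA copy mirrors the original every iteration.
        return True
    if interval <= 0:
        # Assumed meaning of a non-positive interval: update every iteration.
        return True
    return (cur_iter + 1 - start_iter) % interval == 0


fired = [i for i in range(10) if should_update(i, start_iter=4, interval=3)]
print(fired)  # [0, 1, 2, 3, 6, 9]
```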

after_train_iter(runner: mmengine.runner.Runner, batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[dict] = None) -> None

This function is called after each training iteration.

Parameters
  • runner (Runner) – runner to drive the pipeline

  • batch_idx (int) – the index of the batch

  • data_batch (DATA_BATCH, optional) – data batch. Defaults to None.

  • outputs (Optional[dict], optional) – outputs of the training iteration. Defaults to None.
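Putting the pieces together, the following sketch shows what a per-iteration update could look like, based on the behaviour described for module_keys, start_iter, and lerp. It is not mmedit's source. The names ema_after_iter and the (value, trainable) state layout are hypothetical, floats stand in for tensors, and lerp uses the formula assumed earlier (b + (a - b) * m).

```python
from types import SimpleNamespace


def lerp(a, b, momentum=0.001, momentum_nontrainable=1.0, trainable=True):
    # Assumed interpolation formula: move the EMA value `b` toward `a`.
    m = momentum if trainable else momentum_nontrainable
    return b + (a - b) * m


def ema_after_iter(model, module_keys, cur_iter, start_iter=0, **interp_cfg):
    """Hypothetical per-iteration EMA update over dict-based 'state dicts'.

    Each state maps a name to (value, trainable).
    """
    for key in module_keys:
        ema_state = getattr(model, key)
        orig_state = getattr(model, key[:-4])  # drop the '_ema' suffix
        for name, (value, trainable) in orig_state.items():
            if cur_iter < start_iter:
                new_value = value  # keep the EMA copy identical to the original
            else:
                new_value = lerp(value, ema_state[name][0],
                                 trainable=trainable, **interp_cfg)
            ema_state[name] = (new_value, trainable)


model = SimpleNamespace(
    generator={'weight': (1.0, True), 'running_mean': (5.0, False)},
    generator_ema={'weight': (0.0, True), 'running_mean': (0.0, False)},
)
ema_after_iter(model, ('generator_ema',), cur_iter=10, momentum=0.5)
print(model.generator_ema)  # {'weight': (0.5, True), 'running_mean': (5.0, False)}
```

The trainable weight moves halfway toward the original, while the non-trainable buffer is copied outright because momentum_nontrainable defaults to 1.0.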

before_run(runner: mmengine.runner.Runner)

This function is called once before the run starts.

Parameters

runner (Runner) – runner used to drive the whole pipeline

Raises

RuntimeError – error message
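The documentation only states that before_run may raise a RuntimeError. The sketch below shows one plausible pre-run sanity check. The conditions it enforces are assumptions drawn from the module_keys description: each key must end with ‘_ema’, and both the EMA model and its original counterpart must exist. The helper check_ema_keys is hypothetical.

```python
from types import SimpleNamespace


def check_ema_keys(model, module_keys):
    """Hypothetical pre-run check; the conditions are assumptions.

    Raises RuntimeError if a key lacks the '_ema' suffix or if either the
    EMA model or its original counterpart is missing from `model`.
    """
    if isinstance(module_keys, str):
        module_keys = (module_keys,)  # accept str | tuple[str], as documented
    for key in module_keys:
        if not key.endswith('_ema'):
            raise RuntimeError(f"module key '{key}' must end with '_ema'")
        for name in (key, key[:-4]):
            if not hasattr(model, name):
                raise RuntimeError(f"model has no attribute '{name}'")
    return module_keys


model = SimpleNamespace(generator=object(), generator_ema=object())
print(check_ema_keys(model, 'generator_ema'))  # ('generator_ema',)
```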
