ExponentialMovingAverageHook

class mmedit.engine.hooks.ExponentialMovingAverageHook(module_keys, interp_mode='lerp', interp_cfg=None, interval=- 1, start_iter=0)[source]

Exponential Moving Average Hook.

Exponential moving average (EMA) is a trick widely used in the current GAN literature, e.g., PGGAN, StyleGAN, and BigGAN. The general idea is to maintain a model with the same architecture whose parameters are updated as a moving average of the trained weights of the original model. In general, the model with moving-averaged weights achieves better performance.

Parameters
  • module_keys (str | tuple[str]) – The name(s) of the EMA model(s). Each key must end with ‘_ema’ so that the original model can be found by discarding the last four characters.

  • interp_mode (str, optional) – Mode of the interpolation method. Defaults to ‘lerp’.

  • interp_cfg (dict | None, optional) – Set arguments of the interpolation function. Defaults to None.

  • interval (int, optional) – Evaluation interval (by iterations). Default: -1.

  • start_iter (int, optional) – Start iteration for EMA. Until the start iteration is reached, the weights of the EMA model are kept identical to those of the original model. After that, its parameters are updated as a moving average of the trained weights of the original model. Default: 0.
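
As an illustration of the parameters above, the following is a minimal sketch of a config entry for this hook, together with the key convention described for module_keys. The key 'generator_ema', the chosen values, and the helper source_key are hypothetical and only serve the example; the 'custom_hooks' list follows the usual MMEngine-style config layout and is assumed here rather than taken from this page.

```python
# Hypothetical config fragment: the key 'generator_ema' and all values
# below are illustrative assumptions, not defaults taken from the library.
custom_hooks = [
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema',),   # every key must end with '_ema'
        interp_mode='lerp',
        interp_cfg=dict(momentum=0.999),  # forwarded to the interpolation function
        interval=1,
        start_iter=0,
    )
]


def source_key(ema_key: str) -> str:
    """Hypothetical helper: recover the original module name.

    Mirrors the documented convention of discarding the last four
    characters ('_ema') of the EMA key.
    """
    if not ema_key.endswith('_ema'):
        raise ValueError(f"EMA key must end with '_ema', got {ema_key!r}")
    return ema_key[:-4]


# Names of the original modules that the EMA copies would track.
source_keys = [source_key(k) for k in custom_hooks[0]['module_keys']]
```

Here source_keys would contain 'generator', the module whose trained weights the EMA copy averages.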

after_train_iter(runner: mmengine.runner.runner.Runner, batch_idx: int, data_batch: Optional[Sequence[dict]] = None, outputs: Optional[dict] = None) None[source]

This is the function to perform after each training iteration.

Parameters
  • runner (Runner) – runner to drive the pipeline

  • batch_idx (int) – the id of batch

  • data_batch (DATA_BATCH, optional) – data batch. Defaults to None.

  • outputs (Optional[dict], optional) – output. Defaults to None.

before_run(runner: mmengine.runner.runner.Runner)[source]

This is the function to perform before each run.

Parameters

runner (Runner) – runner used to drive the whole pipeline

Raises

RuntimeError – error message

every_n_iters(runner: mmengine.runner.runner.Runner, n: int)[source]

This is the function to perform every n iterations.

Parameters
  • runner (Runner) – runner used to drive the whole pipeline

  • n (int) – the number of iterations

Returns

the latest iterations

Return type

int

static lerp(a, b, momentum=0.999, momentum_nontrainable=0.0, trainable=True)[source]

This is the function to perform linear interpolation between a and b.

Parameters
  • a (float) – number a

  • b (float) – number b

  • momentum (float, optional) – momentum. Defaults to 0.999.

  • momentum_nontrainable (float, optional) – momentum used for non-trainable parameters. Defaults to 0.

  • trainable (bool, optional) – trainable flag. Defaults to True.

Returns

the result of the linear interpolation between a and b

Return type

float
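
To make the interpolation concrete, here is a minimal sketch of how a lerp with these parameters could be written and used for an EMA step. It assumes the momentum weights the difference b - a (so the result is a + (b - a) * m) and that the hook passes the original weight as a and the EMA weight as b; both are assumptions for illustration, not confirmed by this page. The names lerp_sketch and ema_update are hypothetical.

```python
def lerp_sketch(a, b, momentum=0.999, momentum_nontrainable=0.0, trainable=True):
    """Linear interpolation between a and b (assumed form: a + (b - a) * m).

    The momentum actually used depends on the trainable flag.
    """
    m = momentum if trainable else momentum_nontrainable
    return a + (b - a) * m


def ema_update(ema_weight, weight, iteration, start_iter=0, momentum=0.999):
    """Hypothetical single EMA step for one scalar weight.

    Before start_iter the EMA weight simply copies the original weight;
    afterwards it moves toward the original weight by a factor (1 - momentum).
    """
    if iteration < start_iter:
        return weight
    # Assumed argument order: a = original weight, b = EMA weight.
    return lerp_sketch(weight, ema_weight, momentum=momentum)


# Usage: track a weight that is constant at 1.0, starting the EMA from 0.0.
ema = 0.0
for it in range(3):
    ema = ema_update(ema, 1.0, iteration=it, start_iter=0, momentum=0.5)
```

With a momentum close to 1, the EMA weight changes slowly and smooths out noise in the trained weights; with momentum_nontrainable left at 0 and trainable=False, this sketch returns a unchanged.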
