mmedit.evaluation.metrics.swd

Module Contents

Classes

SlicedWassersteinDistance

SWD (Sliced Wasserstein distance) metric.

Functions

sliced_wasserstein(distribution_a, distribution_b[, ...])

Sliced Wasserstein distance of two sets of patches.

get_gaussian_kernel()

Get the Gaussian blur kernel.

get_pyramid_layer(image, gaussian_k[, direction])

Get the pyramid layer.

gaussian_pyramid(original, n_pyramids, gaussian_k)

Get the levels of a Gaussian pyramid.

laplacian_pyramid(original, n_pyramids, gaussian_k)

Calculate Laplacian pyramid.

get_descriptors_for_minibatch(minibatch, nhood_size, ...)

Get descriptors of one level of pyramids.

finalize_descriptors(desc)

Normalize and reshape descriptors.

mmedit.evaluation.metrics.swd.sliced_wasserstein(distribution_a, distribution_b, dir_repeats=4, dirs_per_repeat=128)[source]

Sliced Wasserstein distance of two sets of patches.

Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py

Parameters
  • distribution_a (Tensor) – Descriptors of first distribution.

  • distribution_b (Tensor) – Descriptors of second distribution.

  • dir_repeats (int) – The number of times the projection is repeated. Defaults to 4.

  • dirs_per_repeat (int) – The number of directions per projection. Defaults to 128.

Returns

Sliced Wasserstein distance.

Return type

float
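The computation described above can be sketched in NumPy. This is a minimal illustration modelled on the referenced progressive-GAN code, not the mmedit implementation: it uses NumPy arrays instead of tensors, assumes both descriptor sets have the same shape, and the `seed` argument and the name `sliced_wasserstein_sketch` are hypothetical.

```python
import numpy as np


def sliced_wasserstein_sketch(desc_a, desc_b, dir_repeats=4,
                              dirs_per_repeat=128, seed=0):
    """Approximate SWD between two (num_descriptors, dim) arrays."""
    # Assumption: both sets hold the same number of descriptors.
    assert desc_a.ndim == 2 and desc_a.shape == desc_b.shape
    rng = np.random.default_rng(seed)  # hypothetical seed argument
    results = []
    for _ in range(dir_repeats):
        # Draw random directions and normalize each column to unit length.
        dirs = rng.standard_normal((desc_a.shape[1], dirs_per_repeat))
        dirs /= np.linalg.norm(dirs, axis=0, keepdims=True)
        # Project both sets onto every direction, then sort each projection.
        proj_a = np.sort(desc_a @ dirs, axis=0)
        proj_b = np.sort(desc_b @ dirs, axis=0)
        # The mean absolute gap between sorted projections is the 1-D
        # Wasserstein-1 distance for equally sized samples.
        results.append(np.mean(np.abs(proj_a - proj_b)))
    return float(np.mean(results))


rng = np.random.default_rng(42)
a = rng.standard_normal((256, 27))
same = sliced_wasserstein_sketch(a, a.copy())   # identical sets
shifted = sliced_wasserstein_sketch(a, a + 1.0)  # shifted copy
```

Identical descriptor sets give a distance of zero, while the shifted copy gives a positive value. Raising `dir_repeats` or `dirs_per_repeat` averages over more random directions and so reduces the variance of the estimate.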

mmedit.evaluation.metrics.swd.get_gaussian_kernel()[source]

Get the Gaussian blur kernel.

Returns

Blur kernel.

Return type

Tensor
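This page does not list the kernel values. As a hedged illustration, the sketch below builds the 5x5 binomial blur kernel used by the referenced progressive-GAN code and gives it a (1, 1, 5, 5) shape. It uses NumPy rather than a tensor, and whether `get_gaussian_kernel` returns exactly these weights is an assumption.

```python
import numpy as np

# Binomial approximation of a 1-D Gaussian; the raw weights sum to 16.
taps = np.array([1.0, 4.0, 6.0, 4.0, 1.0]) / 16.0

# The outer product gives the separable 5x5 blur kernel.
kernel_2d = np.outer(taps, taps)

# Add leading (out_channels, in_channels) axes for conv-style use.
gaussian_k = kernel_2d[None, None, :, :].astype(np.float32)
```

The weights sum to one, so blurring with this kernel preserves the mean intensity of the image away from the borders.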

mmedit.evaluation.metrics.swd.get_pyramid_layer(image, gaussian_k, direction='down')[source]

Get the pyramid layer.

Parameters
  • image (Tensor) – Input image.

  • gaussian_k (Tensor) – Gaussian kernel.

  • direction (str, optional) – The direction of the pyramid step. Defaults to ‘down’.

Returns

The resulting pyramid layer.

Return type

Tensor

mmedit.evaluation.metrics.swd.gaussian_pyramid(original, n_pyramids, gaussian_k)[source]

Get the levels of a Gaussian pyramid.

Parameters
  • original (Tensor) – The input image.

  • n_pyramids (int) – The number of pyramids.

  • gaussian_k (Tensor) – The Gaussian kernel.

Returns

The list of Gaussian pyramid levels.

Return type

List[Tensor]

mmedit.evaluation.metrics.swd.laplacian_pyramid(original, n_pyramids, gaussian_k)[source]

Calculate Laplacian pyramid.

Ref: https://github.com/koshian2/swd-pytorch/blob/master/swd.py

Parameters
  • original (Tensor) – Batch of images with range [0, 1] in “NCHW” order.

  • n_pyramids (int) – Number of pyramid levels minus one.

  • gaussian_k (Tensor) – Gaussian kernel with shape (1, 1, 5, 5).

Returns

Laplacian pyramids of original.

Return type

list[Tensor]
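To illustrate how the Gaussian and Laplacian pyramids relate, here is a minimal NumPy sketch. It assumes a 5x5 binomial kernel, zero padding, stride-2 downsampling, and nearest-neighbour upsampling; all function names are hypothetical and the mmedit functions may differ in these details.

```python
import numpy as np


def blur(image, kernel_2d):
    """Blur an NCHW batch per channel with zero padding of 2 pixels."""
    n, c, h, w = image.shape
    padded = np.pad(image, ((0, 0), (0, 0), (2, 2), (2, 2)))
    out = np.zeros_like(image)
    for dy in range(5):
        for dx in range(5):
            out += kernel_2d[dy, dx] * padded[:, :, dy:dy + h, dx:dx + w]
    return out


def pyramid_down(image, kernel_2d):
    # Blur, then keep every second pixel to halve the resolution.
    return blur(image, kernel_2d)[:, :, ::2, ::2]


def pyramid_up(image, kernel_2d):
    # Nearest-neighbour upsample by 2, then blur to smooth the blocks.
    upsampled = image.repeat(2, axis=2).repeat(2, axis=3)
    return blur(upsampled, kernel_2d)


def laplacian_pyramid_sketch(original, n_pyramids, kernel_2d):
    # Gaussian pyramid: the original plus n_pyramids downsampled levels.
    gaussians = [original]
    for _ in range(n_pyramids):
        gaussians.append(pyramid_down(gaussians[-1], kernel_2d))
    # Each Laplacian level is the detail lost between adjacent levels.
    laplacians = [
        fine - pyramid_up(coarse, kernel_2d)
        for fine, coarse in zip(gaussians[:-1], gaussians[1:])
    ]
    # The coarsest Gaussian level closes the pyramid.
    laplacians.append(gaussians[-1])
    return laplacians


taps = np.array([1.0, 4.0, 6.0, 4.0, 1.0]) / 16.0
kernel_2d = np.outer(taps, taps)
images = np.random.default_rng(0).random((2, 3, 32, 32))
levels = laplacian_pyramid_sketch(images, n_pyramids=2, kernel_2d=kernel_2d)
shapes = [level.shape for level in levels]
```

With `n_pyramids=2` the sketch returns three levels at 32x32, 16x16, and 8x8, which matches the description of `n_pyramids` as the number of levels minus one.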

mmedit.evaluation.metrics.swd.get_descriptors_for_minibatch(minibatch, nhood_size, nhoods_per_image)[source]

Get descriptors of one level of pyramids.

Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py

Parameters
  • minibatch (Tensor) – One pyramid level of a batch of images, in “NCHW” order.

  • nhood_size (int) – Pixel neighborhood size.

  • nhoods_per_image (int) – The number of descriptors per image.

Returns

Descriptors of the images in one pyramid-level batch.

Return type

Tensor
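As a rough illustration of what a descriptor is, the sketch below samples `nhoods_per_image` random square patches of side `nhood_size` from every image of an NCHW batch. It is a NumPy stand-in under stated assumptions: `nhood_size` is odd, patch centres are drawn uniformly, and the function name and `seed` argument are hypothetical.

```python
import numpy as np


def extract_descriptors_sketch(minibatch, nhood_size, nhoods_per_image,
                               seed=0):
    """Sample square patches from an NCHW batch as descriptors."""
    n, c, h, w = minibatch.shape
    half = nhood_size // 2  # assumes an odd nhood_size
    rng = np.random.default_rng(seed)  # hypothetical seed argument
    patches = []
    for img in minibatch:
        for _ in range(nhoods_per_image):
            # Pick a centre far enough from the border to fit the patch.
            cy = rng.integers(half, h - half)
            cx = rng.integers(half, w - half)
            patches.append(
                img[:, cy - half:cy + half + 1, cx - half:cx + half + 1])
    # Result shape: (n * nhoods_per_image, c, nhood_size, nhood_size).
    return np.stack(patches)


level = np.random.default_rng(1).random((4, 3, 16, 16))
desc = extract_descriptors_sketch(level, nhood_size=7, nhoods_per_image=8)
```

Here four images with eight patches each yield 32 descriptors of shape (3, 7, 7).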

mmedit.evaluation.metrics.swd.finalize_descriptors(desc)[source]

Normalize and reshape descriptors.

Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py

Parameters

desc (list or Tensor) – List of descriptors of one level.

Returns

Descriptors normalized along the channel dimension and flattened.

Return type

Tensor
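The normalize-and-reshape step can be sketched as follows. This NumPy version follows the referenced progressive-GAN code: it standardizes each channel to zero mean and unit variance over all descriptors and pixels, then flattens every descriptor into one row. The function name is hypothetical, and the exact normalization used by mmedit is an assumption.

```python
import numpy as np


def finalize_descriptors_sketch(desc):
    """Standardize descriptors per channel and flatten each one."""
    if isinstance(desc, list):
        desc = np.concatenate(desc, axis=0)
    assert desc.ndim == 4  # (num_descriptors, channels, height, width)
    # Statistics are taken over everything except the channel axis.
    mean = desc.mean(axis=(0, 2, 3), keepdims=True)
    std = desc.std(axis=(0, 2, 3), keepdims=True)
    desc = (desc - mean) / std
    # One row per descriptor, ready for random projections.
    return desc.reshape(desc.shape[0], -1)


rng = np.random.default_rng(0)
batches = [rng.random((16, 3, 7, 7)) for _ in range(2)]
flat = finalize_descriptors_sketch(batches)
```

Two batches of 16 descriptors become a single (32, 147) array, since each 3x7x7 patch flattens to 147 values.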

class mmedit.evaluation.metrics.swd.SlicedWassersteinDistance(fake_nums: int, image_shape: tuple, fake_key: Optional[str] = None, real_key: Optional[str] = 'img', sample_model: str = 'ema', collect_device: str = 'cpu', prefix: Optional[str] = None)[source]

Bases: mmedit.evaluation.metrics.base_gen_metric.GenMetric

SWD (Sliced Wasserstein distance) metric. We calculate the SWD of two sets of images as follows. In every ‘feed’, we build the Laplacian pyramid of every image and extract patches from each pyramid level as descriptors. In ‘summary’, we normalize these descriptors along the channel dimension and reshape them so that they represent the distribution of the real or fake images. The sliced Wasserstein distance between the real and fake descriptors is then taken as the SWD of the real and fake images.

Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py

Parameters
  • fake_nums (int) – Number of generated images needed for the metric.

  • image_shape (tuple) – Image shape in “CHW” order.

  • fake_key (Optional[str]) – Key for getting fake images from the output dict. Defaults to None.

  • real_key (Optional[str]) – Key for getting real images from the input dict. Defaults to ‘img’.

  • sample_model (str) – Sampling mode for the generative model. Supports ‘orig’ and ‘ema’. Defaults to ‘ema’.

  • collect_device (str) – Device name used for collecting results from different ranks during distributed training. Must be ‘cpu’ or ‘gpu’. Defaults to ‘cpu’.

  • prefix (str, optional) – The prefix that will be added in the metric names to disambiguate homonymous metrics of different evaluators. If prefix is not provided in the argument, self.default_prefix will be used instead. Defaults to None.
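A metric like this is usually declared as a config dict rather than constructed by hand. The fragment below is a hypothetical sketch: the keyword names come from the signature above, but the `type` string (assumed here to be the class name) and all values are illustrative assumptions, not settings taken from this page.

```python
# Hypothetical evaluator config entry; the values are illustrative only.
swd_metric = dict(
    type='SlicedWassersteinDistance',  # assumed registry name (class name)
    fake_nums=16384,                   # number of generated images to use
    image_shape=(3, 64, 64),           # "CHW" order
    real_key='img',                    # default from the signature
    sample_model='ema',                # 'orig' or 'ema'
    collect_device='cpu',              # 'cpu' or 'gpu'
    prefix='swd',                      # illustrative metric-name prefix
)
```

`fake_key` is left at its default of None here.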

name = 'SWD'[source]

process(data_batch: dict, data_samples: Sequence[dict]) → None[source]

Process one batch of data samples and predictions. The processed results should be stored in self.fake_results and self.real_results, which will be used to compute the metrics when all batches have been processed.

Parameters
  • data_batch (dict) – A batch of data from the dataloader.

  • data_samples (Sequence[dict]) – A batch of outputs from the model.

_collect_target_results(target: str) → Optional[list][source]

Collect function for the SWD metric. This function supports collecting results of type List[List[Tensor]].

Parameters

target (str) – Target results to collect.

Returns

The collected results.

Return type

Optional[list]

compute_metrics(results_fake, results_real) → dict[source]

Compute the result of the SWD metric.

Parameters
  • results_fake (list) – List of image features of fake images.

  • results_real (list) – List of image features of real images.

Returns

A dict of the computed SWD metric.

Return type

dict
