Shortcuts

Source code for mmedit.models.base_archs.all_gather_layer

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.autograd as autograd
import torch.distributed as dist


class AllGatherLayer(autograd.Function):
    """All-gather layer with a backward propagation path.

    This function places ``dist.all_gather()`` in the backward graph, so
    that gradients can flow back through the gathered tensors to the local
    input. Operations of this kind are widely used in MoCo and other
    contrastive learning algorithms.

    Use it through ``AllGatherLayer.apply(x)``. A default process group
    must be available, because both passes query ``torch.distributed``.
    """

    @staticmethod
    def forward(ctx, x):
        """Gather ``x`` from every process in the default group.

        Args:
            ctx: Autograd context object (not used to store anything).
            x (Tensor): Local tensor contributed by the current process.

        Returns:
            tuple[Tensor]: One tensor per process (``dist.get_world_size()``
            entries), each allocated with the shape and dtype of ``x`` and
            filled by ``dist.all_gather``.
        """
        # Nothing is saved on ``ctx``: backward only needs the incoming
        # gradients, so keeping ``x`` alive would just cost memory.
        gathered = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, x)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Return the gradient with respect to the local input.

        Args:
            ctx: Autograd context object (unused).
            *grad_outputs (Tensor): Gradients with respect to each gathered
                tensor returned by ``forward``, in the same order.

        Returns:
            Tensor: The entry of ``grad_outputs`` at index
            ``dist.get_rank()``. No communication happens here, so
            gradients attached to the other entries are not propagated by
            this function.
        """
        # Only the slot that originated from this process maps back to the
        # local input.
        return grad_outputs[dist.get_rank()]
Read the Docs v: latest
Versions
master
latest
stable
zyh-re-docs
zyh-doc-notfound-extend
zyh-api-rendering
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.