
BaseGAN

class mmedit.models.base_models.BaseGAN(generator: Union[Dict, torch.nn.modules.module.Module], discriminator: Optional[Union[Dict, torch.nn.modules.module.Module]] = None, data_preprocessor: Optional[Union[dict, mmengine.config.config.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None)

Base class for GAN models.

Parameters
  • generator (ModelType) – The config or model of the generator.

  • discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.

  • data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or GenDataPreprocessor. Defaults to None.

  • generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.

  • discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.

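  • noise_size (Optional[int]) – The size of the input noise vector. Defaults to None.

  • ema_config (Optional[Dict]) – The config for the generator’s exponential moving average (EMA). Defaults to None.

  • loss_config (Optional[Dict]) – The config for the losses. Defaults to None.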

property device: torch.device

Get current device of the model.

Returns

The current device of the model.

Return type

torch.device

property discriminator_steps: int

The number of times the discriminator is completely updated before the generator is updated.

Type

int

forward(inputs: Tuple[Dict[str, Union[torch.Tensor, str, int]], torch.Tensor], data_samples: Optional[list] = None, mode: Optional[str] = None) → Sequence[mmengine.structures.base_data_element.BaseDataElement]

Sample images with the given inputs. If the forward mode is ‘ema’ or ‘orig’, the images generated by the corresponding generator are returned. If the forward mode is ‘ema/orig’, the images generated by the original generator and by the EMA generator are both returned in a dict.

Parameters
  • inputs (ForwardInputs) – Dict containing the information (e.g. noise, num_batches, mode) needed to generate images.

  • data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.

  • mode (Optional[str]) – mode is not used in BaseGAN. Defaults to None.

Returns

A list of EditDataSample containing the generated results.

Return type

SampleList
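The mode dispatch described above can be sketched in plain Python. This is an illustration under stated assumptions, not BaseGAN’s code: the two generators are hypothetical callables that map a batch size to a list of samples, and the dict keys ‘ema’ and ‘orig’ are assumed names for the two outputs.

```python
from typing import Callable, Dict, List, Union

# Hypothetical stand-in: a "generator" maps a batch size to a list of samples.
# In BaseGAN these would be the original generator and its EMA copy.
Generator = Callable[[int], List[float]]


def sample_by_mode(mode: str, orig_gen: Generator, ema_gen: Generator,
                   num_batches: int) -> Union[List[float], Dict[str, List[float]]]:
    """Pick the generator(s) used for sampling according to `mode`."""
    if mode == 'orig':
        return orig_gen(num_batches)
    if mode == 'ema':
        return ema_gen(num_batches)
    if mode == 'ema/orig':
        # Both outputs are returned together; the key names are assumptions.
        return {'ema': ema_gen(num_batches), 'orig': orig_gen(num_batches)}
    raise ValueError(f'Unsupported mode: {mode!r}')
```

For example, `sample_by_mode('ema/orig', lambda n: [0.0] * n, lambda n: [1.0] * n, 2)` returns a dict holding both batches, which mirrors the “both returned in a dict” case.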

static gather_log_vars(log_vars_list: List[Dict[str, torch.Tensor]]) → Dict[str, torch.Tensor]

Gather a list of log_vars.

Parameters

log_vars_list (List[Dict[str, Tensor]]) – The list of log_vars dicts to gather.

Returns

The gathered log_vars.

Return type

Dict[str, Tensor]
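The docstring does not say how the dicts are combined. A minimal sketch, assuming each key is averaged over the list (for example, over several generator updates within one train_step) and that every dict logs the same keys; plain floats stand in for Tensors here.

```python
from statistics import fmean
from typing import Dict, List


def gather_log_vars(log_vars_list: List[Dict[str, float]]) -> Dict[str, float]:
    """Merge per-step log dicts into one dict (assumption: average each key)."""
    if len(log_vars_list) == 1:
        return log_vars_list[0]
    gathered = {}
    for key in log_vars_list[0]:
        # Every dict must log the same keys, otherwise averaging is ill-defined.
        if not all(key in log_vars for log_vars in log_vars_list):
            raise KeyError(f"'{key}' is missing from some log_vars dicts")
        gathered[key] = fmean(log_vars[key] for log_vars in log_vars_list)
    return gathered
```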

property generator_steps: int

The number of times the generator is completely updated before the discriminator is updated.

Type

int

noise_fn(noise: Optional[Union[torch.Tensor, Callable]] = None, num_batches: int = 1)

Sampling function for noise. There are three scenarios in this function:

  • If noise is a callable function, call it to sample num_batches of noise.

  • If noise is None, sample num_batches of noise from a Gaussian distribution.

  • If noise is a torch.Tensor, directly return noise.

Parameters
  • noise (Optional[Union[Tensor, Callable]]) – You can directly give a batch of noise through a torch.Tensor or offer a callable function to sample a batch of noise. Otherwise, None indicates that the default noise sampler is used. Defaults to None.

  • num_batches (int, optional) – The batch size of the noise to sample. If noise is a Tensor, this is ignored. Defaults to 1.

Returns

Sampled noise tensor.

Return type

Tensor
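The three scenarios can be sketched as a small dispatch function. This is a pure-Python illustration, not BaseGAN’s implementation: nested lists stand in for a (num_batches, noise_size) Tensor, random.gauss stands in for a Tensor sampler such as torch.randn, and the noise_size argument with its default of 4 is hypothetical (BaseGAN keeps its own noise_size).

```python
import random
from typing import Callable, List, Optional, Union

# Stand-in for a (num_batches, noise_size) Tensor.
Noise = List[List[float]]


def noise_fn(noise: Optional[Union[Noise, Callable[[int], Noise]]] = None,
             num_batches: int = 1,
             noise_size: int = 4) -> Noise:  # noise_size default is hypothetical
    if callable(noise):
        # Scenario 1: a user-supplied sampler draws num_batches samples.
        return noise(num_batches)
    if noise is None:
        # Scenario 2: default sampler, standard Gaussian N(0, 1).
        return [[random.gauss(0.0, 1.0) for _ in range(noise_size)]
                for _ in range(num_batches)]
    # Scenario 3: an explicit batch is returned unchanged; num_batches is ignored.
    return noise
```

Checking callable() before None and before the pass-through case keeps the branches unambiguous, since a sampler function is neither None nor a batch.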

test_step(data: dict) → Sequence[mmengine.structures.base_data_element.BaseDataElement]

Get the generated images for the given data. Same as val_step().

Parameters

data (dict) – Data sampled from a metric-specific sampler. More details in Metrics and Evaluator.

Returns

Generated image or image dict.

Return type

List[EditDataSample]

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.edit_data_sample.EditDataSample], optimizer_wrapper: mmengine.optim.optimizer.optimizer_wrapper.OptimWrapper) → Dict[str, torch.Tensor]

Training function for discriminator. All GANs should implement this function by themselves.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

train_generator(inputs: dict, data_samples: List[mmedit.structures.edit_data_sample.EditDataSample], optimizer_wrapper: mmengine.optim.optimizer.optimizer_wrapper.OptimWrapper) → Dict[str, torch.Tensor]

Training function for generator. All GANs should implement this function by themselves.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

train_step(data: dict, optim_wrapper: mmengine.optim.optimizer.optimizer_wrapper_dict.OptimWrapperDict) → Dict[str, torch.Tensor]

Train GAN model. In the training of GAN models, the generator and discriminator are updated alternately. In MMEditing’s design, self.train_step is called with data input, so the discriminator, whose update relies on real data, is always updated; whether the generator also needs to be updated is then determined by the current iteration number. More details about whether to update the generator can be found in should_gen_update().

Parameters
  • data (dict) – Data sampled from dataloader.

  • optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance containing the OptimWrapper of the generator and that of the discriminator.

Returns

A dict of tensors for logging.

Return type

Dict[str, torch.Tensor]
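The alternating update can be pictured as a schedule over iterations. The sketch below is a simplified model, not the library’s logic: it assumes the discriminator is updated once per call and that the generator is updated generator_steps times whenever (iteration + 1) is a multiple of discriminator_steps. The actual rule lives in should_gen_update() and may differ (for example, by also accounting for gradient accumulation).

```python
from typing import List, Tuple


def update_schedule(num_iters: int, generator_steps: int = 1,
                    discriminator_steps: int = 1) -> List[Tuple[int, str, int]]:
    """Return (iteration, network, n_updates) events of a simplified schedule."""
    events = []
    for curr_iter in range(num_iters):
        # The discriminator is updated on every call, since it needs real data.
        events.append((curr_iter, 'discriminator', 1))
        # Assumed rule: after discriminator_steps discriminator updates, update
        # the generator generator_steps times in a row.
        if (curr_iter + 1) % discriminator_steps == 0:
            events.append((curr_iter, 'generator', generator_steps))
    return events
```

With discriminator_steps=2, the generator is updated only at iterations 1 and 3 of a four-iteration run, so the discriminator receives twice as many updates.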

val_step(data: dict) → Sequence[mmengine.structures.base_data_element.BaseDataElement]

Get the generated images for the given data.

Calls self.data_preprocessor(data) and self(inputs, data_samples, mode=None) in order, and returns the generated results, which are then passed to the evaluator.

Parameters

data (dict) – Data sampled from a metric-specific sampler. More details in Metrics and Evaluator.

Returns

Generated image or image dict.

Return type

SampleList

property with_ema_gen: bool

Whether the GAN adopts exponential moving average.

Returns

True if this GAN model uses an exponential moving average (EMA) of the generator, False otherwise.

Return type

bool
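When EMA is enabled (presumably through ema_config), an averaged copy of the generator’s weights is maintained alongside the trained one. The sketch below shows one EMA step on a dict of scalar parameters; the decay name and its 0.999 default are hypothetical, and some implementations instead parameterize the step by a momentum weighting the new value (momentum = 1 - decay).

```python
from typing import Dict


def ema_update(ema_params: Dict[str, float], params: Dict[str, float],
               decay: float = 0.999) -> Dict[str, float]:
    """One EMA step: ema <- decay * ema + (1 - decay) * current."""
    # decay is a hypothetical name; larger values make the average change slower.
    return {name: decay * ema_params[name] + (1.0 - decay) * params[name]
            for name in ema_params}
```

For example, `ema_update({'w': 0.0}, {'w': 1.0}, decay=0.5)` moves the averaged weight halfway toward the current one, giving `{'w': 0.5}`.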
