
BaseConditionalGAN

class mmedit.models.base_models.BaseConditionalGAN(generator: Union[Dict, torch.nn.modules.module.Module], discriminator: Optional[Union[Dict, torch.nn.modules.module.Module]] = None, data_preprocessor: Optional[Union[dict, mmengine.config.config.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, num_classes: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None)

Base class for Conditional GAN models.

Parameters
  • generator (ModelType) – The config or model of the generator.

  • discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.

  • data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or GenDataPreprocessor. Defaults to None.

  • generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.

  • discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.

  • noise_size (Optional[int]) – Size of the input noise vector. Defaults to None.

  • num_classes (Optional[int]) – The number of classes you would like to generate. Defaults to None.

  • ema_config (Optional[Dict]) – The config for generator’s exponential moving average setting. Defaults to None.
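To picture how generator_steps and discriminator_steps interleave, here is a minimal sketch assuming a simple alternating loop in which the discriminator is updated first in each round. The helper update_schedule is hypothetical and is not how BaseConditionalGAN schedules its updates internally; it only spells out the order the parameter descriptions suggest.

```python
from typing import List


def update_schedule(generator_steps: int = 1,
                    discriminator_steps: int = 1,
                    num_rounds: int = 2) -> List[str]:
    """Return the order of updates over a few alternating rounds.

    Hypothetical illustration: each round performs `discriminator_steps`
    discriminator updates ('D') followed by `generator_steps` generator
    updates ('G'). This is an assumed simplification, not MMEditing's loop.
    """
    if generator_steps < 1 or discriminator_steps < 1:
        raise ValueError('step counts must be positive')
    order: List[str] = []
    for _ in range(num_rounds):
        order.extend(['D'] * discriminator_steps)
        order.extend(['G'] * generator_steps)
    return order


print(update_schedule(generator_steps=1, discriminator_steps=2))
# ['D', 'D', 'G', 'D', 'D', 'G']
```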

data_sample_to_label(data_sample: List[mmedit.structures.edit_data_sample.EditDataSample]) → Optional[torch.Tensor]

Get labels from the input data_sample and pack them into a torch.Tensor. If no label is found in the passed data_sample, None is returned.

Parameters

data_sample (List[EditDataSample]) – Input data samples.

Returns

Packed label tensor.

Return type

Optional[torch.Tensor]
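The packing behaviour can be sketched in plain Python, assuming each data sample exposes its label through an attribute. ToyDataSample and its gt_label field are hypothetical stand-ins for EditDataSample, a plain list replaces the returned torch.Tensor, and rejecting partially labelled batches is an assumed design choice rather than documented behaviour.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyDataSample:
    """Hypothetical stand-in for EditDataSample holding an optional label."""
    gt_label: Optional[int] = None


def data_sample_to_label(data_samples: List[ToyDataSample]) -> Optional[List[int]]:
    """Collect labels from data samples; return None if no sample has one.

    A plain list replaces the torch.Tensor returned by the real method.
    """
    labels = [s.gt_label for s in data_samples if s.gt_label is not None]
    if not labels:
        return None
    # Assumption: a batch is either fully labelled or not labelled at all.
    if len(labels) != len(data_samples):
        raise ValueError('either every data sample has a label or none does')
    return labels


print(data_sample_to_label([ToyDataSample(3), ToyDataSample(0)]))  # [3, 0]
print(data_sample_to_label([ToyDataSample(), ToyDataSample()]))    # None
```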

forward(inputs: Tuple[Dict[str, Union[torch.Tensor, str, int]], torch.Tensor], data_samples: Optional[list] = None, mode: Optional[str] = None) → List[mmedit.structures.edit_data_sample.EditDataSample]

Sample images from the given inputs. If the forward mode is ‘ema’ or ‘orig’, the images generated by the corresponding generator are returned. If the forward mode is ‘ema/orig’, the images generated by both the original generator and the EMA generator are returned in a dict.

Parameters
  • inputs (ForwardInputs) – Dict containing the necessary information (e.g. noise, num_batches, mode) to generate images.

  • data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.

  • mode (Optional[str]) – mode is not used in BaseConditionalGAN. Defaults to None.

Returns

Generated images or image dict.

Return type

List[EditDataSample]
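The mode dispatch described above can be sketched as follows. This is a framework-free illustration: the two generators are hypothetical callables that return placeholder strings instead of images, and the dict keys 'ema' and 'orig' used for the combined mode are an assumption.

```python
from typing import Callable, Dict, List, Union

Images = List[str]  # placeholder for a batch of generated images


def sample_by_mode(mode: str,
                   orig_generator: Callable[[int], Images],
                   ema_generator: Callable[[int], Images],
                   num_batches: int = 1) -> Union[Images, Dict[str, Images]]:
    """Dispatch on the sampling mode described for forward.

    'orig' -> original generator, 'ema' -> EMA generator,
    'ema/orig' -> a dict holding the outputs of both (keys assumed).
    """
    if mode == 'orig':
        return orig_generator(num_batches)
    if mode == 'ema':
        return ema_generator(num_batches)
    if mode == 'ema/orig':
        return {'ema': ema_generator(num_batches),
                'orig': orig_generator(num_batches)}
    raise ValueError(f'unsupported sampling mode: {mode!r}')


def orig(n: int) -> Images:
    return ['orig'] * n


def ema(n: int) -> Images:
    return ['ema'] * n


print(sample_by_mode('ema/orig', orig, ema, num_batches=1))
# {'ema': ['ema'], 'orig': ['orig']}
```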

label_fn(label: Optional[Union[torch.Tensor, Callable, List[int]]] = None, num_batches: int = 1) → torch.Tensor

Sampling function for labels. It handles three scenarios:

  • If label is a callable, call it to sample num_batches labels.

  • If label is None, sample num_batches labels uniformly from [0, self.num_classes-1].

  • If label is a torch.Tensor, check that all of its values lie in [0, self.num_classes-1]. If they do, return label directly.

Parameters
  • label (Union[Tensor, Callable, List[int], None]) – Either a batch of labels given directly as a torch.Tensor, or a callable that samples a batch of labels. None means the default label sampler is used. Defaults to None.

  • num_batches (int, optional) – The number of labels to sample. Ignored if label is a Tensor. Defaults to 1.

Returns

Sampled label tensor.

Return type

Tensor
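The three scenarios can be mirrored in a short plain-Python sketch. Lists stand in for torch.Tensor, the num_classes argument replaces self.num_classes, and raising ValueError for out-of-range labels is an assumption about how invalid input is reported.

```python
import random
from typing import Callable, List, Union

LabelSource = Union[List[int], Callable[[int], List[int]], None]


def label_fn(label: LabelSource = None, num_batches: int = 1,
             num_classes: int = 10) -> List[int]:
    """Plain-Python sketch of the three label_fn scenarios."""
    if callable(label):
        # Scenario 1: delegate sampling to the user-supplied function.
        return label(num_batches)
    if label is None:
        # Scenario 2: uniform sampling from [0, num_classes - 1] (inclusive).
        return [random.randint(0, num_classes - 1) for _ in range(num_batches)]
    # Scenario 3: validate the given labels; num_batches is ignored.
    if not all(0 <= x <= num_classes - 1 for x in label):
        raise ValueError(f'labels must lie in [0, {num_classes - 1}]')
    return list(label)


print(label_fn([1, 2], num_batches=5, num_classes=3))  # [1, 2]
print(label_fn(lambda n: [0] * n, num_batches=4))     # [0, 0, 0, 0]
```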

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.edit_data_sample.EditDataSample], optimizer_wrapper: mmengine.optim.optimizer.optimizer_wrapper.OptimWrapper) → Dict[str, torch.Tensor]

Training function for the discriminator. Every GAN subclass should implement this method itself.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

train_generator(inputs: dict, data_samples: List[mmedit.structures.edit_data_sample.EditDataSample], optimizer_wrapper: mmengine.optim.optimizer.optimizer_wrapper.OptimWrapper) → Dict[str, torch.Tensor]

Training function for the generator. Every GAN subclass should implement this method itself.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]
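Since both training functions are left for subclasses to implement, the overall contract can be illustrated with a framework-free skeleton. Everything here is hypothetical: ToyOptimWrapper only loosely imitates the update_params method of mmengine's OptimWrapper, the losses are fixed placeholder numbers rather than real GAN losses, and the log keys 'loss_disc' and 'loss_gen' are invented for illustration.

```python
from typing import Dict, List


class ToyOptimWrapper:
    """Hypothetical stand-in for mmengine's OptimWrapper; records losses."""

    def __init__(self) -> None:
        self.updates: List[float] = []

    def update_params(self, loss: float) -> None:
        # A real wrapper would backpropagate and step the optimizer here.
        self.updates.append(loss)


class ToyBaseGAN:
    """Mirrors the contract: subclasses must provide both training methods."""

    def train_discriminator(self, inputs: dict, data_samples: list,
                            optimizer_wrapper: ToyOptimWrapper) -> Dict[str, float]:
        raise NotImplementedError

    def train_generator(self, inputs: dict, data_samples: list,
                        optimizer_wrapper: ToyOptimWrapper) -> Dict[str, float]:
        raise NotImplementedError


class ToyConditionalGAN(ToyBaseGAN):
    def train_discriminator(self, inputs, data_samples, optimizer_wrapper):
        loss_disc = 0.7  # placeholder; a real model computes this from D's outputs
        optimizer_wrapper.update_params(loss_disc)
        return {'loss_disc': loss_disc}  # dict of values for logging

    def train_generator(self, inputs, data_samples, optimizer_wrapper):
        loss_gen = 1.2  # placeholder; a real model computes this from G's samples
        optimizer_wrapper.update_params(loss_gen)
        return {'loss_gen': loss_gen}


wrapper = ToyOptimWrapper()
model = ToyConditionalGAN()
print(model.train_discriminator({}, [], wrapper))  # {'loss_disc': 0.7}
print(model.train_generator({}, [], wrapper))      # {'loss_gen': 1.2}
```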
