mmedit.models.base_models.base_conditional_gan

Module Contents

Classes

BaseConditionalGAN

Base class for Conditional GAN models.

Attributes

ModelType

mmedit.models.base_models.base_conditional_gan.ModelType [source]

class mmedit.models.base_models.base_conditional_gan.BaseConditionalGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, num_classes: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None) [source]

Bases: mmedit.models.base_models.base_gan.BaseGAN

Base class for Conditional GAN models.

Parameters
  • generator (ModelType) – The config or model of the generator.

  • discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.

  • data_preprocessor (Optional[Union[dict, Config]]) – The pre-processing config or a GenDataPreprocessor instance. Defaults to None.

  • generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.

  • discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.

  • noise_size (Optional[int]) – Size of the input noise vector. Defaults to None.

  • num_classes (Optional[int]) – The number of classes you would like to generate. Defaults to None.

  • ema_config (Optional[Dict]) – The config for generator’s exponential moving average setting. Defaults to None.
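The generator_steps and discriminator_steps parameters describe how updates of the two networks alternate. The plain-Python sketch below spells out one schedule consistent with those descriptions: in each round the discriminator is updated discriminator_steps times and then the generator generator_steps times. The helper update_schedule is hypothetical and is not necessarily the loop mmedit runs internally.

```python
from typing import List


def update_schedule(generator_steps: int = 1,
                    discriminator_steps: int = 1,
                    num_rounds: int = 1) -> List[str]:
    """Return the order of updates as 'D' / 'G' markers.

    Hypothetical helper: in every round the discriminator is fully updated
    `discriminator_steps` times, then the generator `generator_steps` times.
    """
    if generator_steps < 1 or discriminator_steps < 1:
        raise ValueError('step counts must be positive')
    schedule: List[str] = []
    for _ in range(num_rounds):
        schedule.extend(['D'] * discriminator_steps)
        schedule.extend(['G'] * generator_steps)
    return schedule


print(update_schedule(generator_steps=1, discriminator_steps=2, num_rounds=2))
# ['D', 'D', 'G', 'D', 'D', 'G']
```

With both parameters left at their default of 1, the two networks simply alternate one update each.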

label_fn(label: mmedit.utils.typing.LabelVar = None, num_batches: int = 1) → torch.Tensor [source]

Sampling function for labels. This function handles three scenarios:

  • If label is a callable, it is called to sample num_batches labels.

  • If label is None, num_batches labels are sampled uniformly from the range [0, self.num_classes-1].

  • If label is a torch.Tensor, its values are checked to lie in [0, self.num_classes-1]. If all values are in the valid range, label is returned directly.

Parameters
  • label (Union[Tensor, Callable, List[int], None]) – A batch of labels given directly as a torch.Tensor, or a callable that samples a batch of labels. None means the default label sampler is used. Defaults to None.

  • num_batches (int, optional) – The number of labels to sample, i.e. the batch size. Ignored if label is a Tensor. Defaults to 1.

Returns

Sampled label tensor.

Return type

Tensor
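The three label_fn scenarios can be illustrated with a framework-free sketch. Plain Python lists stand in for torch.Tensor, the function name sample_labels is hypothetical, and the callable is assumed to take the batch size as its only argument; this is not mmedit's implementation.

```python
import random
from typing import Callable, List, Sequence, Union

# Plain lists stand in for torch.Tensor in this sketch.
LabelLike = Union[Sequence[int], Callable[[int], List[int]], None]


def sample_labels(label: LabelLike, num_classes: int,
                  num_batches: int = 1) -> List[int]:
    """Framework-free sketch of the three label_fn scenarios."""
    if callable(label):
        # Scenario 1: delegate sampling to the user-provided function
        # (assumed to accept the batch size as its only argument).
        return list(label(num_batches))
    if label is None:
        # Scenario 2: uniform sampling over [0, num_classes - 1];
        # random.randint includes both endpoints.
        return [random.randint(0, num_classes - 1) for _ in range(num_batches)]
    # Scenario 3: validate explicit labels; num_batches is ignored.
    labels = list(label)
    if any(not 0 <= x < num_classes for x in labels):
        raise ValueError(f'labels must lie in [0, {num_classes - 1}]')
    return labels


print(sample_labels([0, 3, 9], num_classes=10))  # [0, 3, 9]
print(len(sample_labels(None, num_classes=10, num_batches=4)))  # 4
```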

data_sample_to_label(data_sample: List[mmedit.structures.EditDataSample]) → Optional[torch.Tensor] [source]

Get labels from the input data samples and pack them into a torch.Tensor. If no label is found in data_sample, None is returned.

Parameters

data_sample (List[EditDataSample]) – Input data samples.

Returns

Packed label tensor.

Return type

Optional[torch.Tensor]
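A minimal sketch of the behaviour described for data_sample_to_label, assuming each data sample is modelled as a plain dict with an optional 'label' key (a hypothetical stand-in for EditDataSample) and a plain list in place of the packed tensor. Rejecting partially labelled batches is also an assumption of the sketch, not documented behaviour.

```python
from typing import Any, Dict, List, Optional


def data_samples_to_labels(
        data_samples: List[Dict[str, Any]]) -> Optional[List[int]]:
    """Sketch: gather per-sample labels; return None if no sample has one.

    Each data sample is a dict with an optional 'label' key, a
    hypothetical stand-in for EditDataSample.
    """
    labels = [s['label'] for s in data_samples if 'label' in s]
    if not labels:
        return None  # no label found in any sample
    if len(labels) != len(data_samples):
        # Assumption: a batch must be fully labelled or fully unlabelled.
        raise ValueError('some data samples are missing a label')
    return labels


print(data_samples_to_labels([{'label': 1}, {'label': 4}]))  # [1, 4]
print(data_samples_to_labels([{}, {}]))  # None
```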

static _get_valid_num_classes(num_classes: Optional[int], generator: ModelType, discriminator: Optional[ModelType]) → int [source]

Try to get num_classes from the input argument, the generator, and the discriminator, and check that these values are consistent. If no conflict is found, return num_classes.

Parameters
  • num_classes (Optional[int]) – The num_classes passed to BaseConditionalGAN’s __init__ function.

  • generator (ModelType) – The config or model of the generator.

  • discriminator (Optional[ModelType]) – The config or model of the discriminator.

Returns

The number of classes to be generated.

Return type

int
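The consistency check performed by _get_valid_num_classes can be sketched as follows. The helper names are hypothetical; configs are modelled as dicts and built models as objects exposing a num_classes attribute. Raising an error when no source provides a value is an assumption of this sketch, not documented behaviour.

```python
from typing import Any, Optional


def _num_classes_of(model: Any) -> Optional[int]:
    # Configs are modelled as dicts, built models as objects with an attribute.
    if model is None:
        return None
    if isinstance(model, dict):
        return model.get('num_classes')
    return getattr(model, 'num_classes', None)


def get_valid_num_classes(num_classes: Optional[int], generator: Any,
                          discriminator: Optional[Any]) -> int:
    """Sketch: collect num_classes from all sources and check they agree."""
    candidates = [num_classes, _num_classes_of(generator),
                  _num_classes_of(discriminator)]
    found = {c for c in candidates if c is not None}
    if len(found) > 1:
        raise ValueError(f'conflicting num_classes values: {sorted(found)}')
    if not found:
        # Assumption: treat "no source specifies num_classes" as an error.
        raise ValueError('num_classes is not specified anywhere')
    return found.pop()


print(get_valid_num_classes(None, {'num_classes': 10}, None))  # 10
```

Using a set makes agreement between any number of sources a single length check.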

forward(inputs: mmedit.utils.typing.ForwardInputs, data_samples: Optional[list] = None, mode: Optional[str] = None) → List[mmedit.structures.EditDataSample] [source]

Sample images with the given inputs. If the forward mode is ‘ema’ or ‘orig’, the images generated by the corresponding generator are returned. If the forward mode is ‘ema/orig’, images generated by both the original generator and the EMA generator are returned in a dict.

Parameters
  • inputs (ForwardInputs) – Dict containing the necessary information (e.g. noise, num_batches, mode) to generate images.

  • data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.

  • mode (Optional[str]) – mode is not used in BaseConditionalGAN. Defaults to None.

Returns

Generated images or image dict.

Return type

List[EditDataSample]
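The mode dispatch described for forward can be sketched with two hypothetical generator callables that map a batch size to a list of images. The dict keys 'ema' and 'orig' are assumptions chosen for illustration; since the documented return type is List[EditDataSample], the real method packs its results differently.

```python
from typing import Callable, Dict, List, Union

# Hypothetical generator: batch size -> list of images (strings here).
Generator = Callable[[int], List[str]]


def sample_by_mode(mode: str, orig_gen: Generator, ema_gen: Generator,
                   num_batches: int = 1) -> Union[List[str], Dict[str, List[str]]]:
    """Sketch of choosing the generator output by forward mode."""
    if mode == 'orig':
        return orig_gen(num_batches)
    if mode == 'ema':
        return ema_gen(num_batches)
    if mode == 'ema/orig':
        # Both generators are sampled and returned together in a dict.
        return {'ema': ema_gen(num_batches), 'orig': orig_gen(num_batches)}
    raise ValueError(f'unsupported mode: {mode!r}')


orig = lambda n: ['orig_img'] * n
ema = lambda n: ['ema_img'] * n
print(sample_by_mode('ema/orig', orig, ema))
# {'ema': ['ema_img'], 'orig': ['orig_img']}
```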

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor] [source]

Training function for the generator. All GANs should implement this function themselves.

Parameters
  • inputs (dict) – Inputs from the dataloader.

  • data_samples (List[EditDataSample]) – Data samples from the dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor] [source]

Training function for the discriminator. All GANs should implement this function themselves.

Parameters
  • inputs (dict) – Inputs from the dataloader.

  • data_samples (List[EditDataSample]) – Data samples from the dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]
