mmedit.models.editors.singan.singan

Module Contents

Classes

SinGAN

SinGAN.

Attributes

ModelType

TrainInput

mmedit.models.editors.singan.singan.ModelType[source]
mmedit.models.editors.singan.singan.TrainInput[source]
class mmedit.models.editors.singan.singan.SinGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, num_scales: Optional[int] = None, iters_per_scale: int = 2000, noise_weight_init: int = 0.1, lr_scheduler_args: Optional[dict] = None, test_pkl_data: Optional[str] = None, ema_confg: Optional[dict] = None)[source]

Bases: mmedit.models.base_models.BaseGAN

SinGAN.

This model implements the single-image generative adversarial model proposed in: SinGAN: Learning a Generative Model from a Single Natural Image, ICCV’19.

Notes for training:

  • This model should be trained with our dataset SinGANDataset.

  • In training, the total_iters argument is related to the number of scales in the image pyramid and to iters_per_scale in the train_cfg. Set it carefully in the training config file.

Notes for model architectures:

  • The generator and discriminator need num_scales at initialization. However, this argument is generated by the create_real_pyramid function in singan_dataset.py. The last element in the returned list (stop_scale) is the value for num_scales. Note that this scale is counted from zero. See our SinGAN tutorial for more details, or our standard config for reference.
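As a rough illustration of how total_iters relates to the image pyramid, the sketch below assumes that every scale, counted from zero up to and including num_scales, is trained for exactly iters_per_scale iterations. The helper name singan_total_iters is hypothetical and not part of mmedit; check the result against the standard config before relying on it.

```python
def singan_total_iters(num_scales: int, iters_per_scale: int = 2000) -> int:
    """Estimate the total number of training iterations for SinGAN.

    Assumption: scales are indexed 0..num_scales (inclusive) and each
    scale is trained for `iters_per_scale` iterations.
    """
    if num_scales < 0:
        raise ValueError("num_scales is counted from zero and cannot be negative")
    return (num_scales + 1) * iters_per_scale


# stop_scale is the last element returned by create_real_pyramid;
# the value 8 is only an example.
stop_scale = 8
total_iters = singan_total_iters(stop_scale, iters_per_scale=2000)
print(total_iters)
```

Because the scale is counted from zero, a stop_scale of 8 corresponds to nine training stages under this assumption.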

Parameters
  • generator (ModelType) – The config or model of the generator.

  • discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.

  • data_preprocessor (Optional[Union[dict, Config]]) – The pre-processing config or a GenDataPreprocessor instance. Defaults to None.

  • generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.

  • discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.

  • num_scales (Optional[int]) – The number of scales/stages in the generator and discriminator. Note that this number is counted from zero, as in the original paper. Defaults to None.

  • iters_per_scale (int) – The number of training iterations for each resolution scale. Defaults to 2000.

  • noise_weight_init (float) – The initial weight of the fixed noise. Defaults to 0.1.

  • lr_scheduler_args (Optional[dict]) – Arguments for the learning rate schedulers. Note that SinGAN uses MultiStepLR, as in the original paper. If not passed, no learning rate scheduler will be used. Defaults to None.

  • test_pkl_data (Optional[str]) – The path of the pickle file that contains the fixed noise and noise weight. This is required for testing. Defaults to None.

  • ema_config (Optional[Dict]) – The config for generator’s exponential moving average setting. Defaults to None.
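The parameters above might be combined in a config along the following lines. This is a minimal sketch, not the project's standard config: only the top-level keyword names come from the signature above, while the inner generator and discriminator type names and all numeric values are assumptions chosen for illustration.

```python
# Hypothetical values throughout; only the top-level keyword names come
# from the SinGAN signature documented above.
num_scales = 8  # counted from zero, i.e. nine scales in total
iters_per_scale = 2000

model = dict(
    type='SinGAN',
    # The inner configs are placeholders; the type names are assumed
    # and real configs would need further arguments.
    generator=dict(type='SinGANMultiScaleGenerator', num_scales=num_scales),
    discriminator=dict(
        type='SinGANMultiScaleDiscriminator', num_scales=num_scales),
    generator_steps=1,
    discriminator_steps=1,
    num_scales=num_scales,
    iters_per_scale=iters_per_scale,
    noise_weight_init=0.1,
    # Keyword arguments of torch.optim.lr_scheduler.MultiStepLR.
    lr_scheduler_args=dict(milestones=[1600], gamma=0.1),
    test_pkl_data=None,  # set this to a pickle path for testing
)
```

Keeping num_scales in a single variable helps the generator, the discriminator, and the model agree on the number of stages.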

load_test_pkl()[source]

Load pickle for test.

_from_numpy(data: Tuple[list, numpy.ndarray]) → Tuple[torch.Tensor, List[torch.Tensor]][source]

Convert an input numpy array, or a list of numpy arrays, to a Tensor or a list of Tensors.

Parameters

data (Tuple[list, np.ndarray]) – Input data to convert.

Returns

The converted Tensor or list of Tensors.

Return type

Tuple[Tensor, List[Tensor]]

get_module(model: torch.nn.Module, module_name: str) → torch.nn.Module[source]

Get an inner module from model.

Since some models may be wrapped with DDP, we have to check whether the module can be indexed directly.

Parameters
  • model (nn.Module) – The model, which may or may not be wrapped with DDP.

  • module_name (str) – The name of specific module.

Returns

The requested submodule.

Return type

nn.Module
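The unwrapping idea can be sketched without PyTorch. The classes below are plain-Python stand-ins, and get_inner_module is a hypothetical helper rather than the library's implementation; the only assumption carried over from DDP is that the wrapper exposes the original model as its module attribute.

```python
class InnerNet:
    """Stand-in for a model that owns a named submodule."""

    def __init__(self):
        self.blocks = ["scale0", "scale1"]


class FakeDDP:
    """Stand-in for a DDP-style wrapper exposing the model as `.module`."""

    def __init__(self, module):
        self.module = module


def get_inner_module(model, module_name: str):
    """Return `module_name` from `model`, unwrapping one wrapper level."""
    # A wrapped model keeps the real model under `.module`, so look
    # there when the attribute cannot be reached directly.
    if not hasattr(model, module_name) and hasattr(model, "module"):
        model = model.module
    return getattr(model, module_name)


plain = InnerNet()
wrapped = FakeDDP(plain)
print(get_inner_module(plain, "blocks") is get_inner_module(wrapped, "blocks"))
```

Both calls reach the same object, so calling code does not need to know whether the model was wrapped.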

construct_fixed_noises()[source]

Construct the fixed noises list used in SinGAN.

forward(inputs: mmedit.utils.ForwardInputs, data_samples: Optional[list] = None, mode=None) → List[mmedit.structures.EditDataSample][source]

Forward function for SinGAN. For SinGAN, inputs should be a dict containing ‘num_batches’, ‘mode’ and other input arguments for the generator.

Parameters
  • inputs (dict) – Dict containing the necessary information (e.g., noise, num_batches, mode) to generate image.

  • data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.

  • mode (Optional[str]) – mode is not used in SinGAN. Defaults to None.

gen_loss(disc_pred_fake: torch.Tensor, recon_imgs: torch.Tensor) → Tuple[torch.Tensor, dict][source]

Generator loss for SinGAN. SinGAN uses WGAN’s loss and MSE loss to train the generator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • recon_imgs (Tensor) – Reconstructed images.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]
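The combination of a WGAN term and an MSE term can be sketched with plain Python floats. This is an illustration under stated assumptions, not the library's code: the function name, the explicit real_pixels argument, the log keys, and the recon_weight factor are all hypothetical, and the documented method receives only disc_pred_fake and recon_imgs.

```python
from typing import Dict, Sequence, Tuple


def mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def generator_loss(
    disc_pred_fake: Sequence[float],
    recon_pixels: Sequence[float],
    real_pixels: Sequence[float],
    recon_weight: float = 10.0,  # hypothetical weighting factor
) -> Tuple[float, Dict[str, float]]:
    """WGAN generator term plus a weighted MSE reconstruction term."""
    # WGAN: the generator tries to raise the critic's score on fakes.
    loss_adv = -mean(disc_pred_fake)
    # MSE between the reconstruction and the real image.
    loss_recon = mean(
        [(r - t) ** 2 for r, t in zip(recon_pixels, real_pixels)])
    loss = loss_adv + recon_weight * loss_recon
    log_vars = {"loss_adv": loss_adv, "loss_recon": loss_recon, "loss": loss}
    return loss, log_vars


loss, log_vars = generator_loss([0.5, 1.5], [0.0, 1.0], [0.0, 0.0])
print(loss, log_vars)
```

Returning the individual terms alongside the total mirrors the documented return type of a loss value plus a dict of log variables.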

disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor, fake_data: torch.Tensor, real_data: torch.Tensor) → Tuple[torch.Tensor, dict][source]

Discriminator loss for SinGAN. SinGAN uses WGAN’s loss with a gradient penalty to train the discriminator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

  • fake_data (Tensor) – Generated images, used to calculate gradient penalty.

  • real_data (Tensor) – Real images, used to calculate gradient penalty.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]
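The parameter list mentions a gradient penalty computed from fake_data and real_data. A minimal sketch of a WGAN-style critic loss with such a penalty is shown below, using plain floats. It assumes the gradient norms of the critic at interpolated samples have already been computed (that step needs autograd and is omitted); the function name, the log keys, and gp_weight are hypothetical.

```python
from typing import Dict, Sequence, Tuple


def mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def discriminator_loss(
    disc_pred_fake: Sequence[float],
    disc_pred_real: Sequence[float],
    grad_norms: Sequence[float],
    gp_weight: float = 10.0,  # hypothetical penalty weight
) -> Tuple[float, Dict[str, float]]:
    """Wasserstein critic loss plus a gradient penalty term."""
    # The critic should score real samples high and fake samples low.
    loss_wgan = mean(disc_pred_fake) - mean(disc_pred_real)
    # Penalize gradient norms that deviate from 1.
    loss_gp = mean([(g - 1.0) ** 2 for g in grad_norms])
    loss = loss_wgan + gp_weight * loss_gp
    log_vars = {"loss_wgan": loss_wgan, "loss_gp": loss_gp, "loss": loss}
    return loss, log_vars


loss, log_vars = discriminator_loss(
    [0.0, 1.0], [2.0, 3.0], grad_norms=[1.0, 3.0], gp_weight=0.5)
print(loss, log_vars)
```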

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor][source]

Train generator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader. Not used in the generator’s training.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor][source]

Train discriminator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[EditDataSample]) – Data samples from dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

train_gan(inputs_dict: dict, data_sample: List[mmedit.structures.EditDataSample], optim_wrapper: mmengine.optim.OptimWrapperDict) → Dict[str, torch.Tensor][source]

Train GAN model. In the training of GAN models, the generator and discriminator are updated alternately. In MMGeneration’s design, self.train_step is called with data input. Therefore we always update the discriminator, whose update relies on real data, and then determine whether the generator needs to be updated based on the current number of iterations. More details about whether to update the generator can be found in should_gen_update().

Parameters
  • inputs_dict (dict) – Data sampled from dataloader.

  • data_sample (List[EditDataSample]) – List of data samples containing ground truth and meta information.

  • optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance containing the OptimWrapper instances of the generator and discriminator.

Returns

A dict of tensors for logging.

Return type

Dict[str, torch.Tensor]
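The alternating update described for train_gan can be sketched as a small simulation. The rule used here, updating the generator once every discriminator_steps iterations, is an assumption for illustration and ignores generator_steps; should_gen_update_sketch is a hypothetical stand-in, not the library's should_gen_update().

```python
from typing import List


def should_gen_update_sketch(curr_iter: int, discriminator_steps: int) -> bool:
    """Hypothetical rule: the generator is updated after every
    `discriminator_steps` discriminator updates."""
    return (curr_iter + 1) % discriminator_steps == 0


def simulate_updates(num_iters: int, discriminator_steps: int) -> List[List[str]]:
    """Record which networks would be updated at each iteration."""
    schedule = []
    for curr_iter in range(num_iters):
        # The discriminator is always updated, since real data is available.
        updated = ["disc"]
        if should_gen_update_sketch(curr_iter, discriminator_steps):
            updated.append("gen")
        schedule.append(updated)
    return schedule


schedule = simulate_updates(num_iters=4, discriminator_steps=2)
print(schedule)
```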

train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict) → Dict[str, torch.Tensor][source]

Train step for SinGAN model. SinGAN is trained with multi-resolution images, and each resolution is trained for self.iters_per_scale iterations.

We initialize the weights and learning rate scheduler of the corresponding module at the start of each resolution’s training. At the end of each resolution’s training, we update the noise weight of the current resolution using the MSE loss between the reconstructed image and the real image.

Parameters
  • data (dict) – Data sampled from dataloader.

  • optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance containing the OptimWrapper instances of the generator and discriminator.

Returns

A dict of tensors for logging.

Return type

Dict[str, torch.Tensor]
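The per-resolution schedule can be sketched as a mapping from the global iteration to the current scale. This assumes scales are trained back to back, each for exactly iters_per_scale iterations; scale_schedule is a hypothetical helper, and the actual bookkeeping in the library may differ.

```python
from typing import Tuple


def scale_schedule(curr_iter: int, iters_per_scale: int = 2000) -> Tuple[int, bool, bool]:
    """Map a global iteration to (scale index, is_first_iter, is_last_iter)."""
    scale = curr_iter // iters_per_scale
    offset = curr_iter % iters_per_scale
    # First iteration of a scale: initialize weights and the LR scheduler.
    is_first = offset == 0
    # Last iteration of a scale: update the noise weight from the MSE loss.
    is_last = offset == iters_per_scale - 1
    return scale, is_first, is_last


for it in (0, 1999, 2000):
    print(it, scale_schedule(it))
```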

test_step(data: dict) → mmedit.utils.SampleList[source]

Gets the generated images for the given data during testing. Before generating images, we call self.load_test_pkl to load the fixed noise and the current stage of the model from the pickle file.

Parameters

data (dict) – Data sampled from a metric-specific sampler. See Metrics and Evaluator for more details.

Returns

A list of EditDataSample containing the generated results.

Return type

SampleList
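The pickle round trip that test_step depends on can be sketched with the standard library. The key names and the use of plain lists are assumptions for illustration; the actual file stores the fixed noise, noise weights, and current stage in whatever layout the training code writes.

```python
import os
import pickle
import tempfile

# Hypothetical layout: the key names are assumptions, and plain lists
# stand in for the arrays a real training run would store.
state = {
    "fixed_noises": [[0.0, 0.1], [0.2, 0.3, 0.4]],  # one entry per scale
    "noise_weights": [1.0, 0.1],
    "curr_stage": 1,
}

with tempfile.TemporaryDirectory() as tmp_dir:
    pkl_path = os.path.join(tmp_dir, "singan_test.pkl")
    # Written once after training ...
    with open(pkl_path, "wb") as f:
        pickle.dump(state, f)
    # ... and read back before generating test images.
    with open(pkl_path, "rb") as f:
        loaded = pickle.load(f)

print(sorted(loaded))
```

In a real setup, the path of such a file would be passed as test_pkl_data.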
