mmedit.models.editors.singan.singan

Module Contents

Classes

SinGAN

SinGAN.

Attributes

ModelType

TrainInput

mmedit.models.editors.singan.singan.ModelType[source]
mmedit.models.editors.singan.singan.TrainInput[source]
class mmedit.models.editors.singan.singan.SinGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, num_scales: Optional[int] = None, iters_per_scale: int = 2000, noise_weight_init: int = 0.1, lr_scheduler_args: Optional[dict] = None, test_pkl_data: Optional[str] = None, ema_confg: Optional[dict] = None)[source]

Bases: mmedit.models.base_models.BaseGAN

SinGAN.

This model implements the single-image generative adversarial model proposed in: SinGAN: Learning a Generative Model from a Single Natural Image, ICCV’19.

Notes for training:

  • This model should be trained with our dataset SinGANDataset.

  • In training, the total_iters argument depends on the number of scales in the image pyramid and on iters_per_scale in the train_cfg. Set it carefully in the training config file.

Notes for model architectures:

  • The generator and discriminator need num_scales at initialization. However, this argument is produced by the create_real_pyramid function in singan_dataset.py. The last element in the returned list (stop_scale) is the value for num_scales. Note that this scale is counted from zero. See our tutorial for SinGAN for more details, or our standard config for reference.
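The relationship between num_scales, iters_per_scale, and total_iters described in the notes above can be sketched as follows. This is a minimal illustration, assuming each scale is trained for iters_per_scale iterations and that a zero-based num_scales means num_scales + 1 pyramid levels. The helper name total_training_iters is hypothetical and is not part of mmedit.

```python
def total_training_iters(num_scales: int, iters_per_scale: int = 2000) -> int:
    """Iterations needed to train every scale once (assumed relationship)."""
    if num_scales < 0:
        raise ValueError("num_scales is counted from zero and must be >= 0")
    # num_scales is zero-based, so the pyramid has num_scales + 1 levels.
    return (num_scales + 1) * iters_per_scale


# Example: a stop_scale of 8 corresponds to 9 pyramid levels.
print(total_training_iters(num_scales=8, iters_per_scale=2000))
```

Under this assumption, changing either num_scales or iters_per_scale without adjusting total_iters would end training too early or leave it running after the last scale has finished.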

Parameters
  • generator (ModelType) – The config or model of the generator.

  • discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.

  • data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or GenDataPreprocessor.

  • generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.

  • discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.

  • num_scales (int) – The number of scales/stages in the generator and discriminator. Note that this number is counted from zero, as in the original paper. Defaults to None.

  • iters_per_scale (int) – The training iteration for each resolution scale. Defaults to 2000.

  • noise_weight_init (float) – The initial weight of the fixed noise. Defaults to 0.1.

  • lr_scheduler_args (Optional[dict]) – Arguments for the learning rate schedulers. Note that SinGAN uses MultiStepLR, as in the original paper. If not passed, no learning rate schedule is used. Defaults to None.

  • test_pkl_data (Optional[str]) – The path of the pickle file that contains the fixed noise and noise weight. Required for testing. Defaults to None.

  • ema_config (Optional[Dict]) – The config for the generator’s exponential moving average setting. Note that the class signature above spells this argument ema_confg. Defaults to None.
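As a rough illustration of how the parameters above fit together, the following config fragment passes the documented arguments by name. It is a sketch only: the type strings for the generator, discriminator, and data preprocessor are assumed registry names, the scheduler values are placeholders, and the remaining generator and discriminator arguments are omitted.

```python
# Illustrative config fragment. The `type` strings below are assumed
# registry names and the numeric values are placeholders, not
# recommended settings.
num_scales = 8  # zero-based stop_scale obtained for the training image

model = dict(
    type='SinGAN',
    generator=dict(
        type='SinGANMultiScaleGenerator',  # assumed name, other args omitted
        num_scales=num_scales),
    discriminator=dict(
        type='SinGANMultiScaleDiscriminator',  # assumed name, other args omitted
        num_scales=num_scales),
    data_preprocessor=dict(type='GenDataPreprocessor'),  # assumed name
    generator_steps=1,
    discriminator_steps=1,
    num_scales=num_scales,
    iters_per_scale=2000,
    noise_weight_init=0.1,
    # Keyword arguments for MultiStepLR (placeholder values).
    lr_scheduler_args=dict(milestones=[1600], gamma=0.1),
    # Only needed at test time: path to the pickle with the fixed noise.
    test_pkl_data=None,
)
```

The same num_scales value is given to the model, the generator, and the discriminator so that all three agree on the number of pyramid levels.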

load_test_pkl()[source]

Load pickle for test.
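The pickle file is described as holding the fixed noise and noise weight needed at test time. The round trip below sketches what saving and restoring such a state could look like with the standard pickle module. The key names (fixed_noises, noise_weights, curr_stage) and the helper functions are assumptions made for illustration, not a documented file format, and plain lists stand in for arrays.

```python
import io
import pickle

# Hypothetical payload layout: the key names are assumptions.
payload = {
    'fixed_noises': [[0.0, 0.1], [0.2, 0.3, 0.4]],  # one entry per scale
    'noise_weights': [1.0, 0.1],                    # one weight per scale
    'curr_stage': 1,                                # last trained scale
}

REQUIRED_KEYS = {'fixed_noises', 'noise_weights', 'curr_stage'}


def dump_state(state: dict) -> bytes:
    """Serialize the test-time state to bytes."""
    return pickle.dumps(state)


def load_state(blob: bytes) -> dict:
    """Deserialize the state and fail early if a required key is missing."""
    state = pickle.load(io.BytesIO(blob))
    missing = REQUIRED_KEYS - set(state)
    if missing:
        raise KeyError(f'pickle is missing keys: {sorted(missing)}')
    return state


restored = load_state(dump_state(payload))
```

Only unpickle files you trust, since pickle can execute arbitrary code while loading.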

_from_numpy(data: Tuple[list, numpy.ndarray]) → Tuple[torch.Tensor, List[torch.Tensor]][source]

Convert an input NumPy array, or a list of NumPy arrays, to a Tensor or a list of Tensors.

Parameters

data (Tuple[list, np.ndarray]) – Input data to convert.

Returns

The converted Tensor or list of Tensors.

Return type

Tuple[Tensor, List[Tensor]]

get_module(model: torch.nn.Module, module_name: str) → torch.nn.Module[source]

Get an inner module from model.

Because some models are wrapped with DDP, we have to check whether the module can be accessed directly.

Parameters
  • model (nn.Module) – The model, which may or may not be wrapped with DDP.

  • module_name (str) – The name of the specific module.

Returns

The requested submodule.

Return type

nn.Module
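The idea behind get_module can be sketched without any deep learning dependency. A DDP-style wrapper keeps the real model under a module attribute, so a lookup has to go one level deeper when the wrapper is present. The classes and the function below are hypothetical stand-ins, and the hasattr check is an assumption about how the wrapped case is detected.

```python
class Inner:
    """Stand-in for a model that owns named submodules."""

    def __init__(self):
        self.generator = 'generator-module'


class Wrapper:
    """Stand-in for a DDP-style wrapper exposing the model as `.module`."""

    def __init__(self, module):
        self.module = module


def get_inner_module(model, module_name: str):
    """Return the named submodule, unwrapping a DDP-style wrapper if needed."""
    # Assumption: a wrapped model is recognized by its `.module` attribute.
    target = model.module if hasattr(model, 'module') else model
    return getattr(target, module_name)


plain = Inner()
wrapped = Wrapper(Inner())
print(get_inner_module(plain, 'generator'))
print(get_inner_module(wrapped, 'generator'))
```

Both calls resolve to the same kind of object, which lets the calling code ignore whether distributed training is active.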

construct_fixed_noises()[source]

Construct the fixed noises list used in SinGAN.

forward(inputs: mmedit.utils.ForwardInputs, data_samples: Optional[list] = None, mode=None) → List[mmedit.structures.EditDataSample][source]

Forward function for SinGAN. For SinGAN, inputs should be a dict containing ‘num_batches’, ‘mode’, and other input arguments for the generator.

Parameters
  • inputs (dict) – Dict containing the necessary information (e.g., noise, num_batches, mode) to generate images.

  • data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.

  • mode (Optional[str]) – mode is not used in SinGAN. Defaults to None.

gen_loss(disc_pred_fake: torch.Tensor, recon_imgs: torch.Tensor) → Tuple[torch.Tensor, dict][source]

Generator loss for SinGAN. SinGAN uses WGAN’s loss and MSE loss to train the generator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • recon_imgs (Tensor) – Reconstructed images.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]
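The combination of a WGAN term and an MSE term can be written out on plain Python numbers. This is a sketch under stated assumptions: flat lists of floats stand in for tensors, the real image is passed explicitly, and recon_weight (set to 10.0 here) is a hypothetical balancing factor. The function name generator_loss is illustrative and is not the library method.

```python
from typing import Dict, Sequence, Tuple


def mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def generator_loss(
    disc_pred_fake: Sequence[float],
    recon_pixels: Sequence[float],
    real_pixels: Sequence[float],
    recon_weight: float = 10.0,  # assumed weighting, not taken from the docs
) -> Tuple[float, Dict[str, float]]:
    """WGAN generator term plus a weighted MSE reconstruction term."""
    # WGAN generator term: push the critic's score on fakes upward.
    loss_adv = -mean(disc_pred_fake)
    # Reconstruction term: mean squared error against the real image.
    loss_recon = mean([(r - t) ** 2 for r, t in zip(recon_pixels, real_pixels)])
    loss = loss_adv + recon_weight * loss_recon
    log_vars = {'loss_adv': loss_adv, 'loss_recon': loss_recon, 'loss': loss}
    return loss, log_vars


loss, log_vars = generator_loss([1.0, 3.0], [0.0, 2.0], [0.0, 0.0])
```

Returning the individual terms alongside the total mirrors the documented return value of a loss plus a dict of log variables.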

disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor, fake_data: torch.Tensor, real_data: torch.Tensor) → Tuple[torch.Tensor, dict][source]

Discriminator loss for SinGAN. SinGAN uses WGAN’s loss with a gradient penalty to train the discriminator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

  • fake_data (Tensor) – Generated images, used to calculate gradient penalty.

  • real_data (Tensor) – Real images, used to calculate gradient penalty.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]
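The parameters above mention a gradient penalty computed from fake_data and real_data. The arithmetic of a WGAN-style critic loss with such a penalty can be sketched as below. This is an assumption-laden illustration: lists of floats stand in for tensors, grad_norms is assumed to hold precomputed L2 norms of the critic's gradient at samples interpolated between real and fake data, and gp_weight is a hypothetical coefficient.

```python
from typing import Dict, Sequence, Tuple


def mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def discriminator_loss(
    disc_pred_fake: Sequence[float],
    disc_pred_real: Sequence[float],
    grad_norms: Sequence[float],  # assumed to be precomputed elsewhere
    gp_weight: float = 0.1,       # assumed coefficient
) -> Tuple[float, Dict[str, float]]:
    """WGAN critic loss plus a gradient penalty on precomputed norms."""
    # Critic term: score fakes low and real images high.
    loss_wgan = mean(disc_pred_fake) - mean(disc_pred_real)
    # Penalty term: keep gradient norms close to 1.
    loss_gp = mean([(g - 1.0) ** 2 for g in grad_norms])
    loss = loss_wgan + gp_weight * loss_gp
    return loss, {'loss_wgan': loss_wgan, 'loss_gp': loss_gp, 'loss': loss}


loss, log_vars = discriminator_loss(
    [0.0, 2.0], [3.0, 5.0], grad_norms=[1.0, 3.0], gp_weight=0.5)
```

In a tensor-based implementation the gradient norms would come from automatic differentiation rather than being passed in.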

train_generator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor][source]

Train generator.

Parameters
  • inputs (dict) – Inputs from the dataloader.

  • data_samples (List[EditDataSample]) – Data samples from the dataloader. Not used in generator training.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

train_discriminator(inputs: dict, data_samples: List[mmedit.structures.EditDataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor][source]

Train discriminator.

Parameters
  • inputs (dict) – Inputs from the dataloader.

  • data_samples (List[EditDataSample]) – Data samples from the dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

train_gan(inputs_dict: dict, data_sample: List[mmedit.structures.EditDataSample], optim_wrapper: mmengine.optim.OptimWrapperDict) → Dict[str, torch.Tensor][source]

Train GAN model. In the training of GAN models, the generator and discriminator are updated alternately. In MMGeneration’s design, self.train_step is called with data input. Therefore we always update the discriminator, whose update relies on real data, and then determine whether the generator needs to be updated based on the current number of iterations. More details about whether to update the generator can be found in should_gen_update().

Parameters
  • inputs_dict (dict) – Data sampled from the dataloader.

  • data_sample (List[EditDataSample]) – List of data samples containing ground truth and meta information.

  • optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance containing the OptimWrapper of the generator and of the discriminator.

Returns

A dict of tensors for logging.

Return type

Dict[str, torch.Tensor]
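The update order described for train_gan (discriminator on every iteration, generator only when due) can be simulated with a few lines of Python. The rule used in gen_update_due is an assumed stand-in for should_gen_update(): the generator is treated as due once every discriminator_steps iterations. Both helper names are hypothetical.

```python
from typing import List


def gen_update_due(curr_iter: int, discriminator_steps: int) -> bool:
    """Assumed rule: the generator is due every `discriminator_steps` iterations."""
    return (curr_iter + 1) % discriminator_steps == 0


def update_schedule(
    num_iters: int,
    discriminator_steps: int = 1,
    generator_steps: int = 1,
) -> List[str]:
    """List the order of discriminator ('D') and generator ('G') updates."""
    log: List[str] = []
    for curr_iter in range(num_iters):
        # The discriminator is updated on every iteration.
        log.append('D')
        # The generator is updated only when the rule says it is due.
        if gen_update_due(curr_iter, discriminator_steps):
            log.extend(['G'] * generator_steps)
    return log


print(update_schedule(4, discriminator_steps=2, generator_steps=1))
```

With discriminator_steps=2 and generator_steps=1, the simulated schedule performs two discriminator updates for each generator update.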

train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict) → Dict[str, torch.Tensor][source]

Train step for the SinGAN model. SinGAN is trained with multi-resolution images, and each resolution is trained for self.iters_per_scale iterations.

We initialize the weights and the learning rate scheduler of the corresponding module at the start of each resolution’s training. At the end of each resolution’s training, we update the noise weight of the current resolution using the MSE loss between the reconstructed image and the real image.

Parameters
  • data (dict) – Data sampled from the dataloader.

  • optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance containing the OptimWrapper of the generator and of the discriminator.

Returns

A dict of tensors for logging.

Return type

Dict[str, torch.Tensor]
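The per-resolution bookkeeping described for train_step can be sketched with integer arithmetic. The helpers below are hypothetical. They assume the iteration counter starts at zero, that a resolution starts whenever the counter is a multiple of iters_per_scale, and that the noise weight is the initial weight scaled by the square root of the MSE (the RMSE scaling used in the SinGAN paper). Lists of floats stand in for images.

```python
import math
from typing import Sequence


def stage_index(curr_iter: int, iters_per_scale: int = 2000) -> int:
    """Zero-based resolution being trained at `curr_iter`."""
    return curr_iter // iters_per_scale


def is_stage_start(curr_iter: int, iters_per_scale: int = 2000) -> bool:
    """True on the first iteration of a resolution (re-initialize here)."""
    return curr_iter % iters_per_scale == 0


def is_stage_end(curr_iter: int, iters_per_scale: int = 2000) -> bool:
    """True on the last iteration of a resolution (update noise weight here)."""
    return (curr_iter + 1) % iters_per_scale == 0


def updated_noise_weight(
    noise_weight_init: float,
    recon_pixels: Sequence[float],
    real_pixels: Sequence[float],
) -> float:
    """Scale the initial weight by the RMSE of the reconstruction (assumed)."""
    mse = sum((r - t) ** 2 for r, t in zip(recon_pixels, real_pixels))
    mse /= len(recon_pixels)
    return noise_weight_init * math.sqrt(mse)


weight = updated_noise_weight(0.1, [1.0, 1.0], [3.0, 3.0])
```

A poorer reconstruction at the current resolution yields a larger noise weight under this scaling, so more noise is injected where the generator is further from the real image.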

test_step(data: dict) → mmedit.utils.SampleList[source]

Get the generated images for the given data during testing. Before generating images, we call self.load_test_pkl to load the fixed noise and the current stage of the model from the pickle file.

Parameters

data (dict) – Data sampled from the metric-specific sampler. More details in Metrics and Evaluator.

Returns

A list of EditDataSample containing the generated results.

Return type

SampleList
