mmedit.models.editors.stable_diffusion

Package Contents

Classes

StableDiffusion

Class to run the Stable Diffusion pipeline.

class mmedit.models.editors.stable_diffusion.StableDiffusion(diffusion_scheduler, unet, vae, requires_safety_checker=True, unet_sample_size=64, init_cfg=None)[source]

Bases: mmengine.model.BaseModel

Class to run the Stable Diffusion pipeline.

Parameters
  • diffusion_scheduler (dict) – Diffusion scheduler config.

  • unet (dict) – UNet config.

  • vae (dict) – VAE config.

  • pretrained_ckpt_path (dict) – Pretrained checkpoint paths for the submodels in Stable Diffusion.

  • requires_safety_checker (bool) – Whether to run the safety checker after an image is generated. Defaults to True.

  • unet_sample_size (int) – Sample size for the UNet. Defaults to 64.
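The constructor is config-driven. A hypothetical instantiation config might look like the following sketch — the submodel `type` names are placeholders, not verified registry entries; consult the mmedit model zoo for working configs:

```python
# Hypothetical config sketch. The submodel `type` names below are
# placeholders and may not match real mmedit registry entries;
# check the mmedit model zoo before use.
stable_diffusion = dict(
    type='StableDiffusion',
    diffusion_scheduler=dict(type='EditDDIMScheduler'),  # placeholder
    unet=dict(type='DenoisingUnet'),                     # placeholder
    vae=dict(type='AutoencoderKL'),                      # placeholder
    requires_safety_checker=True,   # run the NSFW check on outputs
    unet_sample_size=64,            # latent size; 64 * 8 = 512 px output
)
```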

init_weights()[source]

Load the pretrained checkpoint for each submodel.

to(torch_device: Optional[Union[str, torch.device]] = None)[source]

Put the submodels on the given torch device.

Parameters

torch_device (Optional[Union[str, torch.device]]) – Device to move the submodels to. Defaults to None.

Returns

The class instance itself.

Return type

self(StableDiffusion)
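Because to() returns the instance itself, device placement can be chained with other calls. A minimal sketch of the pattern (a stub for illustration, not mmedit code — the real method moves each torch submodule):

```python
class MiniPipeline:
    """Minimal sketch (not mmedit code) of the submodel-moving pattern
    behind StableDiffusion.to(): move each registered submodel, then
    return self so calls can be chained."""

    def __init__(self):
        # stand-ins for unet / vae / text_encoder submodules
        self.submodels = {"unet": [], "vae": [], "text_encoder": []}
        self.device = "cpu"

    def to(self, torch_device=None):
        if torch_device is None:
            return self          # nothing to do, still chainable
        # a real implementation would call module.to(torch_device)
        # on every submodule here
        self.device = torch_device
        return self

pipe = MiniPipeline().to("cuda:0")   # returns the instance itself
```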

infer(prompt: Union[str, List[str]], height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[torch.Generator] = None, latents: Optional[torch.FloatTensor] = None, show_progress=True, seed=1)[source]

Function invoked when calling the pipeline for generation.

Parameters
  • prompt (str or List[str]) – The prompt or prompts to guide the image generation.

  • height (int, optional) – The height in pixels of the generated image. Defaults to self.unet_sample_size * self.vae_scale_factor.

  • width (int, optional) – The width in pixels of the generated image. Defaults to self.unet_sample_size * self.vae_scale_factor.
  • num_inference_steps (int, optional, defaults to 50) – The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.

  • guidance_scale (float, optional, defaults to 7.5) – Guidance scale as defined in Classifier-Free Diffusion Guidance (https://arxiv.org/abs/2207.12598). Higher values encourage images that match the prompt more closely.

  • negative_prompt (str or List[str], optional) – The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if guidance_scale is less than 1).

  • num_images_per_prompt (int, optional, defaults to 1) – The number of images to generate per prompt.

  • eta (float, optional, defaults to 0.0) – Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [schedulers.DDIMScheduler], will be ignored for others.

  • generator (torch.Generator, optional) – A torch.Generator to make generation deterministic.

  • latents (torch.FloatTensor, optional) – Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random generator.

Returns

A dict with keys ['samples', 'nsfw_content_detected']:
'samples': the generated image samples; 'nsfw_content_detected': NSFW-content flags for the image samples.

Return type

dict
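The guidance_scale parameter controls the classifier-free guidance step that infer() applies at each denoising iteration: the conditional noise prediction is pushed away from the unconditional one. A scalar sketch of that combination rule (illustrative names, not the mmedit implementation):

```python
def cfg_combine(uncond, cond, guidance_scale):
    """Classifier-free guidance combination, sketched per element:
    result = uncond + guidance_scale * (cond - uncond).
    With guidance_scale = 1.0 this reduces to the conditional
    prediction; larger values amplify the prompt's influence."""
    return [u + guidance_scale * (c - u) for u, c in zip(uncond, cond)]

# guidance_scale = 1.0 reduces to the conditional prediction
assert cfg_combine([0.0], [2.0], 1.0) == [2.0]
```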

_encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt)[source]

Encodes the prompt into text encoder hidden states.

Parameters
  • prompt (str or List[str]) – Prompt to be encoded.

  • device (torch.device) – Torch device.

  • num_images_per_prompt (int) – number of images that should be generated per prompt.

  • do_classifier_free_guidance (bool) – whether to use classifier free guidance or not.

  • negative_prompt (str or List[str]) – The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if guidance_scale is less than 1).

Returns

text embeddings generated by clip text encoder.

Return type

text_embeddings (torch.Tensor)
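A sketch of the batching logic such an encoder typically performs (illustrative, not the actual implementation): embeddings are repeated once per requested image, and under classifier-free guidance the negative/unconditional embeddings are concatenated in front so one batched forward pass covers both branches.

```python
def expand_embeddings(cond, uncond, num_images_per_prompt, do_cfg):
    """Sketch of _encode_prompt's batching (illustrative names):
    repeat each prompt embedding once per image; when classifier-free
    guidance is enabled, prepend the negative/unconditional
    embeddings so both passes share one batched forward call."""
    cond = [e for e in cond for _ in range(num_images_per_prompt)]
    if not do_cfg:
        return cond
    uncond = [e for e in uncond for _ in range(num_images_per_prompt)]
    return uncond + cond
```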

run_safety_checker(image, device, dtype)[source]

Run the safety checker to determine whether the image contains NSFW content.

Parameters
  • image (numpy.ndarray) – image generated by stable diffusion.

  • device (torch.device) – device to run safety checker.

  • dtype (torch.dtype) – float type to run.

Returns

image (numpy.ndarray): a black image for each sample where NSFW content was detected, otherwise the input image. has_nsfw_concept (list[bool]): flag list indicating which samples were flagged as NSFW.

Return type

image (numpy.ndarray)
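The replacement step can be sketched as follows (images as flat float lists for illustration; the real method operates on numpy.ndarray batches):

```python
def apply_safety_flags(images, has_nsfw_concept):
    """Sketch of the post-check step: replace any image flagged by
    the safety checker with an all-black image of the same size,
    leaving unflagged images untouched."""
    return [
        [0.0] * len(img) if flagged else img
        for img, flagged in zip(images, has_nsfw_concept)
    ]
```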

decode_latents(latents)[source]

use vae to decode latents.

Parameters

latents (torch.Tensor) – latents to decode.

Returns

image result.

Return type

image (numpy.ndarray)
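After the VAE decode, values in roughly [-1, 1] are mapped to the [0, 1] pixel range. A scalar sketch of that rescaling (an assumption about this implementation, but standard across Stable Diffusion pipelines):

```python
def to_pixel_range(x):
    """Sketch of decode_latents' final rescaling: the VAE decoder
    emits values in roughly [-1, 1]; map them to [0, 1] and clamp."""
    return min(max(x / 2 + 0.5, 0.0), 1.0)
```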

prepare_extra_step_kwargs(generator, eta)[source]

prepare extra kwargs for the scheduler step.

Parameters
  • generator (torch.Generator) – generator for random functions.

  • eta (float) – Corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be in [0, 1]. Only used with DDIMScheduler; ignored for other schedulers.

Returns

dict contains ‘generator’ and ‘eta’

Return type

extra_step_kwargs (dict)
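A common way to implement this, and likely what this helper does (treat it as an assumption), is to inspect the scheduler's step signature and forward eta/generator only when the scheduler accepts them:

```python
import inspect

def build_step_kwargs(step_fn, generator, eta):
    """Sketch: pass eta / generator to scheduler.step() only if its
    signature accepts them (DDIM takes eta; most schedulers do not)."""
    params = inspect.signature(step_fn).parameters
    kwargs = {}
    if "eta" in params:
        kwargs["eta"] = eta
    if "generator" in params:
        kwargs["generator"] = generator
    return kwargs

# toy stand-ins for two scheduler step methods
def ddim_step(sample, eta=0.0, generator=None): ...
def ddpm_step(sample): ...
```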

check_inputs(prompt, height, width)[source]

Check whether the inputs are in a suitable format.
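Typical checks in Stable Diffusion pipelines (the exact mmedit checks are assumptions) validate the prompt type and require spatial sizes divisible by the VAE scale factor:

```python
def check_inputs(prompt, height, width):
    """Sketch of typical input validation (the exact mmedit checks
    are assumptions): prompt must be a str or list, and spatial
    sizes must be divisible by 8, the VAE downsampling factor."""
    if not isinstance(prompt, (str, list)):
        raise TypeError(f"prompt must be str or list, got {type(prompt)}")
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(
            f"height and width must be divisible by 8, got {height} and {width}")
```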

prepare_latents(batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None)[source]

prepare latents for diffusion to run in latent space.

Parameters
  • batch_size (int) – batch size.

  • num_channels_latents (int) – latent channel nums.

  • height (int) – image height.

  • width (int) – image width.

  • dtype (torch.dtype) – float type.

  • device (torch.device) – torch device.

  • generator (torch.Generator) – generator for random functions, defaults to None.

  • latents (torch.Tensor) – Pre-generated noisy latents, defaults to None.

Returns

prepared latents.

Return type

latents (torch.Tensor)
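The latent tensor's spatial size is the pixel size divided by the VAE scale factor (8 for Stable Diffusion). A sketch of the shape computation:

```python
def latent_shape(batch_size, num_channels_latents, height, width,
                 vae_scale_factor=8):
    """Sketch of the shape prepare_latents samples: the VAE
    downsamples spatially by vae_scale_factor, so latents live at
    (height / 8) x (width / 8) for Stable Diffusion."""
    return (batch_size, num_channels_latents,
            height // vae_scale_factor, width // vae_scale_factor)
```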

abstract forward(inputs: torch.Tensor, data_samples: Optional[list] = None, mode: str = 'tensor') Union[Dict[str, torch.Tensor], list][source]

forward is not implemented yet.
