|
|
|
|
@ -1,6 +1,7 @@
|
|
|
|
|
from collections import deque
|
|
|
|
|
import torch
|
|
|
|
|
import inspect
|
|
|
|
|
import einops
|
|
|
|
|
import k_diffusion.sampling
|
|
|
|
|
from modules import prompt_parser, devices, sd_samplers_common
|
|
|
|
|
|
|
|
|
|
@ -40,6 +41,90 @@ sampler_extra_params = {
|
|
|
|
|
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class CFGDenoiserEdit(torch.nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
|
|
|
|
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
|
|
|
|
|
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
|
|
|
|
|
negative prompt.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, model):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.inner_model = model
|
|
|
|
|
self.mask = None
|
|
|
|
|
self.nmask = None
|
|
|
|
|
self.init_latent = None
|
|
|
|
|
self.step = 0
|
|
|
|
|
|
|
|
|
|
def combine_denoised(self, x_out, conds_list, uncond, cond_scale, image_cfg_scale):
|
|
|
|
|
denoised_uncond = x_out[-uncond.shape[0]:]
|
|
|
|
|
denoised = torch.clone(denoised_uncond)
|
|
|
|
|
|
|
|
|
|
for i, conds in enumerate(conds_list):
|
|
|
|
|
for cond_index, weight in conds:
|
|
|
|
|
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
|
|
|
|
denoised[i] = out_uncond[cond_index] + cond_scale * (out_cond[cond_index] - out_img_cond[cond_index]) + image_cfg_scale * (out_img_cond[cond_index] - out_uncond[cond_index])
|
|
|
|
|
|
|
|
|
|
return denoised
|
|
|
|
|
|
|
|
|
|
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond, image_cfg_scale):
|
|
|
|
|
if state.interrupted or state.skipped:
|
|
|
|
|
raise sd_samplers_common.InterruptedException
|
|
|
|
|
|
|
|
|
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
|
|
|
|
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
|
|
|
|
|
|
|
|
|
batch_size = len(conds_list)
|
|
|
|
|
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
|
|
|
|
|
|
|
|
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
|
|
|
|
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
|
|
|
|
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
|
|
|
|
|
|
|
|
|
|
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
|
|
|
|
|
cfg_denoiser_callback(denoiser_params)
|
|
|
|
|
x_in = denoiser_params.x
|
|
|
|
|
image_cond_in = denoiser_params.image_cond
|
|
|
|
|
sigma_in = denoiser_params.sigma
|
|
|
|
|
|
|
|
|
|
if tensor.shape[1] == uncond.shape[1]:
|
|
|
|
|
cond_in = torch.cat([tensor, uncond, uncond])
|
|
|
|
|
|
|
|
|
|
if shared.batch_cond_uncond:
|
|
|
|
|
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
|
|
|
|
|
else:
|
|
|
|
|
x_out = torch.zeros_like(x_in)
|
|
|
|
|
for batch_offset in range(0, x_out.shape[0], batch_size):
|
|
|
|
|
a = batch_offset
|
|
|
|
|
b = a + batch_size
|
|
|
|
|
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
|
|
|
|
|
else:
|
|
|
|
|
x_out = torch.zeros_like(x_in)
|
|
|
|
|
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
|
|
|
|
for batch_offset in range(0, tensor.shape[0], batch_size):
|
|
|
|
|
a = batch_offset
|
|
|
|
|
b = min(a + batch_size, tensor.shape[0])
|
|
|
|
|
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": torch.cat([tensor[a:b]], uncond) , "c_concat": [image_cond_in[a:b]]})
|
|
|
|
|
|
|
|
|
|
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
|
|
|
|
|
|
|
|
|
devices.test_for_nans(x_out, "unet")
|
|
|
|
|
|
|
|
|
|
if opts.live_preview_content == "Prompt":
|
|
|
|
|
sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
|
|
|
|
|
elif opts.live_preview_content == "Negative prompt":
|
|
|
|
|
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
|
|
|
|
|
|
|
|
|
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, image_cfg_scale)
|
|
|
|
|
|
|
|
|
|
if self.mask is not None:
|
|
|
|
|
denoised = self.init_latent * self.mask + self.nmask * denoised
|
|
|
|
|
|
|
|
|
|
self.step += 1
|
|
|
|
|
|
|
|
|
|
return denoised
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CFGDenoiser(torch.nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
@ -78,8 +163,8 @@ class CFGDenoiser(torch.nn.Module):
|
|
|
|
|
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
|
|
|
|
|
|
|
|
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
|
|
|
|
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
|
|
|
|
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
|
|
|
|
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
|
|
|
|
|
|
|
|
|
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
|
|
|
|
|
cfg_denoiser_callback(denoiser_params)
|
|
|
|
|
@ -195,6 +280,9 @@ class KDiffusionSampler:
|
|
|
|
|
return p.steps
|
|
|
|
|
|
|
|
|
|
def initialize(self, p):
|
|
|
|
|
if shared.sd_model.cond_stage_key == "edit" and getattr(p, 'image_cfg_scale', None) != 1:
|
|
|
|
|
self.model_wrap_cfg = CFGDenoiserEdit(self.model_wrap)
|
|
|
|
|
|
|
|
|
|
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
|
|
|
|
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
|
|
|
|
self.model_wrap_cfg.step = 0
|
|
|
|
|
@ -260,13 +348,17 @@ class KDiffusionSampler:
|
|
|
|
|
|
|
|
|
|
self.model_wrap_cfg.init_latent = x
|
|
|
|
|
self.last_latent = x
|
|
|
|
|
|
|
|
|
|
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
|
|
|
|
|
extra_args={
|
|
|
|
|
'cond': conditioning,
|
|
|
|
|
'image_cond': image_conditioning,
|
|
|
|
|
'uncond': unconditional_conditioning,
|
|
|
|
|
'cond_scale': p.cfg_scale
|
|
|
|
|
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
|
|
|
|
'cond_scale': p.cfg_scale,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if hasattr(p, 'image_cfg_scale') and p.image_cfg_scale != 1 and p.image_cfg_scale != None:
|
|
|
|
|
extra_args['image_cfg_scale'] = p.image_cfg_scale
|
|
|
|
|
|
|
|
|
|
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
|
|
|
|
|
|
|
|
|
return samples
|
|
|
|
|
|
|
|
|
|
|