|
|
|
|
@ -6,6 +6,7 @@ import tqdm
|
|
|
|
|
from PIL import Image
|
|
|
|
|
import inspect
|
|
|
|
|
import k_diffusion.sampling
|
|
|
|
|
import torchsde._brownian.brownian_interval
|
|
|
|
|
import ldm.models.diffusion.ddim
|
|
|
|
|
import ldm.models.diffusion.plms
|
|
|
|
|
from modules import prompt_parser, devices, processing, images
|
|
|
|
|
@ -364,7 +365,23 @@ class TorchHijack:
|
|
|
|
|
if noise.shape == x.shape:
|
|
|
|
|
return noise
|
|
|
|
|
|
|
|
|
|
return torch.randn_like(x)
|
|
|
|
|
if x.device.type == 'mps':
|
|
|
|
|
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
|
|
|
|
else:
|
|
|
|
|
return torch.randn_like(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# MPS fix for randn in torchsde
|
|
|
|
|
def torchsde_randn(size, dtype, device, seed):
|
|
|
|
|
if device.type == 'mps':
|
|
|
|
|
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
|
|
|
|
|
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
|
|
|
|
|
else:
|
|
|
|
|
generator = torch.Generator(device).manual_seed(int(seed))
|
|
|
|
|
return torch.randn(size, dtype=dtype, device=device, generator=generator)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KDiffusionSampler:
|
|
|
|
|
@ -415,8 +432,7 @@ class KDiffusionSampler:
|
|
|
|
|
self.model_wrap.step = 0
|
|
|
|
|
self.eta = p.eta or opts.eta_ancestral
|
|
|
|
|
|
|
|
|
|
if self.sampler_noises is not None:
|
|
|
|
|
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises)
|
|
|
|
|
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
|
|
|
|
|
|
|
|
|
extra_params_kwargs = {}
|
|
|
|
|
for param_name in self.extra_params:
|
|
|
|
|
|