|
|
|
|
@ -12,6 +12,7 @@ from ldm.util import default
|
|
|
|
|
from einops import rearrange
|
|
|
|
|
import ldm.modules.attention
|
|
|
|
|
import ldm.modules.diffusionmodules.model
|
|
|
|
|
from torch.nn.functional import silu
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
|
|
|
|
@ -100,14 +101,6 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
|
|
|
|
|
|
|
|
|
return self.to_out(r2)
|
|
|
|
|
|
|
|
|
|
def nonlinearity_hijack(x):
|
|
|
|
|
# swish
|
|
|
|
|
t = torch.sigmoid(x)
|
|
|
|
|
x *= t
|
|
|
|
|
del t
|
|
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def cross_attention_attnblock_forward(self, x):
|
|
|
|
|
h_ = x
|
|
|
|
|
h_ = self.norm(h_)
|
|
|
|
|
@ -245,11 +238,12 @@ class StableDiffusionModelHijack:
|
|
|
|
|
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
|
|
|
|
self.clip = m.cond_stage_model
|
|
|
|
|
|
|
|
|
|
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
|
|
|
|
|
|
|
|
|
if cmd_opts.opt_split_attention_v1:
|
|
|
|
|
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
|
|
|
|
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
|
|
|
|
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
|
|
|
|
ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
|
|
|
|
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
|
|
|
|
|
|
|
|
|
def flatten(el):
|
|
|
|
|
|