@@ -10,6 +10,7 @@ from torch.nn.functional import silu
 import modules.textual_inversion.textual_inversion
 from modules import prompt_parser, devices, sd_hijack_optimizations, shared
 from modules.shared import opts, device, cmd_opts
+from modules.sd_hijack_optimizations import invokeAI_mps_available
 
 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model
@@ -31,8 +32,13 @@ def apply_optimizations():
         print("Applying v1 cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
     elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
-        print("Applying cross attention optimization (InvokeAI).")
-        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+        if not invokeAI_mps_available and shared.device.type == 'mps':
+            print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
+            print("Applying v1 cross attention optimization.")
+            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+        else:
+            print("Applying cross attention optimization (InvokeAI).")
+            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
     elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
         print("Applying cross attention optimization (Doggettx).")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
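For reviewers tracing the new control flow, here is a minimal standalone sketch of the selection order that apply_optimizations() ends up with after this hunk. The function name pick_cross_attention_forward and its boolean/string parameters are illustrative stand-ins for cmd_opts flags, torch.cuda.is_available(), shared.device.type, and invokeAI_mps_available; they are not part of the actual change.

```python
# Sketch of the optimization-selection order after this change.
# All inputs are plain values standing in for the webui's cmd_opts flags,
# the CUDA availability check, and the psutil-based invokeAI_mps_available probe.
def pick_cross_attention_forward(opt_split_attention_v1=False,
                                 disable_opt_split_attention=False,
                                 opt_split_attention_invokeai=False,
                                 opt_split_attention=False,
                                 cuda_available=True,
                                 device_type='cuda',
                                 invokeAI_mps_available=True):
    if opt_split_attention_v1:
        return "split_cross_attention_forward_v1"
    elif not disable_opt_split_attention and (opt_split_attention_invokeai or not cuda_available):
        # New in this hunk: on MPS without psutil, fall back to the v1 optimization
        # instead of the InvokeAI one.
        if not invokeAI_mps_available and device_type == 'mps':
            return "split_cross_attention_forward_v1"
        else:
            return "split_cross_attention_forward_invokeAI"
    elif not disable_opt_split_attention and (opt_split_attention or cuda_available):
        return "split_cross_attention_forward"  # Doggettx
    return None  # no cross attention optimization applied


# Example: MPS machine without psutil and without CUDA -> v1 fallback.
assert pick_cross_attention_forward(cuda_available=False, device_type='mps',
                                    invokeAI_mps_available=False) == "split_cross_attention_forward_v1"
```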