|
|
|
@ -20,7 +20,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
|
|
|
|
|
|
|
|
|
|
|
q = self.to_q(x)
|
|
|
|
q = self.to_q(x)
|
|
|
|
context = default(context, x)
|
|
|
|
context = default(context, x)
|
|
|
|
k = self.to_k(context) * self.scale
|
|
|
|
k = self.to_k(context)
|
|
|
|
v = self.to_v(context)
|
|
|
|
v = self.to_v(context)
|
|
|
|
del context, x
|
|
|
|
del context, x
|
|
|
|
|
|
|
|
|
|
|
|
@ -50,7 +50,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
|
|
|
|
|
|
|
|
|
|
|
q_in = self.to_q(x)
|
|
|
|
q_in = self.to_q(x)
|
|
|
|
context = default(context, x)
|
|
|
|
context = default(context, x)
|
|
|
|
k_in = self.to_k(context)
|
|
|
|
k_in = self.to_k(context) * self.scale
|
|
|
|
v_in = self.to_v(context)
|
|
|
|
v_in = self.to_v(context)
|
|
|
|
del context, x
|
|
|
|
del context, x
|
|
|
|
|
|
|
|
|
|
|
|
|