|
|
|
|
@ -378,7 +378,8 @@ class UniPC:
|
|
|
|
|
condition=None,
|
|
|
|
|
unconditional_condition=None,
|
|
|
|
|
before_sample=None,
|
|
|
|
|
after_sample=None
|
|
|
|
|
after_sample=None,
|
|
|
|
|
after_update=None
|
|
|
|
|
):
|
|
|
|
|
"""Construct a UniPC.
|
|
|
|
|
|
|
|
|
|
@ -394,6 +395,7 @@ class UniPC:
|
|
|
|
|
self.unconditional_condition = unconditional_condition
|
|
|
|
|
self.before_sample = before_sample
|
|
|
|
|
self.after_sample = after_sample
|
|
|
|
|
self.after_update = after_update
|
|
|
|
|
|
|
|
|
|
def dynamic_thresholding_fn(self, x0, t=None):
|
|
|
|
|
"""
|
|
|
|
|
@ -434,15 +436,6 @@ class UniPC:
|
|
|
|
|
noise = self.noise_prediction_fn(x, t)
|
|
|
|
|
dims = x.dim()
|
|
|
|
|
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
|
|
|
|
from pprint import pp
|
|
|
|
|
print("X:")
|
|
|
|
|
pp(x)
|
|
|
|
|
print("sigma_t:")
|
|
|
|
|
pp(sigma_t)
|
|
|
|
|
print("noise:")
|
|
|
|
|
pp(noise)
|
|
|
|
|
print("alpha_t:")
|
|
|
|
|
pp(alpha_t)
|
|
|
|
|
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
|
|
|
|
|
if self.thresholding:
|
|
|
|
|
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
|
|
|
|
|
@ -524,7 +517,7 @@ class UniPC:
|
|
|
|
|
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
|
|
|
|
|
|
|
|
|
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
|
|
|
|
|
print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
|
|
|
|
#print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
|
|
|
|
ns = self.noise_schedule
|
|
|
|
|
assert order <= len(model_prev_list)
|
|
|
|
|
|
|
|
|
|
@ -568,7 +561,7 @@ class UniPC:
|
|
|
|
|
A_p = C_inv_p
|
|
|
|
|
|
|
|
|
|
if use_corrector:
|
|
|
|
|
print('using corrector')
|
|
|
|
|
#print('using corrector')
|
|
|
|
|
C_inv = torch.linalg.inv(C)
|
|
|
|
|
A_c = C_inv
|
|
|
|
|
|
|
|
|
|
@ -627,7 +620,7 @@ class UniPC:
|
|
|
|
|
return x_t, model_t
|
|
|
|
|
|
|
|
|
|
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
|
|
|
|
|
print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
|
|
|
|
|
#print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
|
|
|
|
|
ns = self.noise_schedule
|
|
|
|
|
assert order <= len(model_prev_list)
|
|
|
|
|
dims = x.dim()
|
|
|
|
|
@ -695,7 +688,7 @@ class UniPC:
|
|
|
|
|
D1s = None
|
|
|
|
|
|
|
|
|
|
if use_corrector:
|
|
|
|
|
print('using corrector')
|
|
|
|
|
#print('using corrector')
|
|
|
|
|
# for order 1, we use a simplified version
|
|
|
|
|
if order == 1:
|
|
|
|
|
rhos_c = torch.tensor([0.5], device=b.device)
|
|
|
|
|
@ -755,8 +748,9 @@ class UniPC:
|
|
|
|
|
t_T = self.noise_schedule.T if t_start is None else t_start
|
|
|
|
|
device = x.device
|
|
|
|
|
if method == 'multistep':
|
|
|
|
|
assert steps >= order
|
|
|
|
|
assert steps >= order, "UniPC order must be < sampling steps"
|
|
|
|
|
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
|
|
|
|
print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps")
|
|
|
|
|
assert timesteps.shape[0] - 1 == steps
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
vec_t = timesteps[0].expand((x.shape[0]))
|
|
|
|
|
@ -768,6 +762,8 @@ class UniPC:
|
|
|
|
|
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
|
|
|
|
if model_x is None:
|
|
|
|
|
model_x = self.model_fn(x, vec_t)
|
|
|
|
|
if self.after_update is not None:
|
|
|
|
|
self.after_update(x, model_x)
|
|
|
|
|
model_prev_list.append(model_x)
|
|
|
|
|
t_prev_list.append(vec_t)
|
|
|
|
|
for step in range(order, steps + 1):
|
|
|
|
|
@ -776,13 +772,15 @@ class UniPC:
|
|
|
|
|
step_order = min(order, steps + 1 - step)
|
|
|
|
|
else:
|
|
|
|
|
step_order = order
|
|
|
|
|
print('this step order:', step_order)
|
|
|
|
|
#print('this step order:', step_order)
|
|
|
|
|
if step == steps:
|
|
|
|
|
print('do not run corrector at the last step')
|
|
|
|
|
#print('do not run corrector at the last step')
|
|
|
|
|
use_corrector = False
|
|
|
|
|
else:
|
|
|
|
|
use_corrector = True
|
|
|
|
|
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
|
|
|
|
if self.after_update is not None:
|
|
|
|
|
self.after_update(x, model_x)
|
|
|
|
|
for i in range(order - 1):
|
|
|
|
|
t_prev_list[i] = t_prev_list[i + 1]
|
|
|
|
|
model_prev_list[i] = model_prev_list[i + 1]
|
|
|
|
|
|