|
|
|
|
@ -3,7 +3,8 @@
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
|
|
|
|
|
from modules import shared
|
|
|
|
|
from modules import shared, devices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UniPCSampler(object):
|
|
|
|
|
def __init__(self, model, **kwargs):
|
|
|
|
|
@ -16,8 +17,8 @@ class UniPCSampler(object):
|
|
|
|
|
|
|
|
|
|
def register_buffer(self, name, attr):
|
|
|
|
|
if type(attr) == torch.Tensor:
|
|
|
|
|
if attr.device != torch.device("cuda"):
|
|
|
|
|
attr = attr.to(torch.device("cuda"))
|
|
|
|
|
if attr.device != devices.device:
|
|
|
|
|
attr = attr.to(devices.device)
|
|
|
|
|
setattr(self, name, attr)
|
|
|
|
|
|
|
|
|
|
def set_hooks(self, before_sample, after_sample, after_update):
|
|
|
|
|
|