|
|
|
|
@ -1,5 +1,6 @@
|
|
|
|
|
import torch
|
|
|
|
|
import os
|
|
|
|
|
import collections
|
|
|
|
|
from collections import namedtuple
|
|
|
|
|
from modules import shared, devices, script_callbacks
|
|
|
|
|
from modules.paths import models_path
|
|
|
|
|
@ -30,6 +31,7 @@ base_vae = None
|
|
|
|
|
loaded_vae_file = None
|
|
|
|
|
checkpoint_info = None
|
|
|
|
|
|
|
|
|
|
checkpoints_loaded = collections.OrderedDict()
|
|
|
|
|
|
|
|
|
|
def get_base_vae(model):
|
|
|
|
|
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
|
|
|
|
@ -149,13 +151,30 @@ def load_vae(model, vae_file=None):
|
|
|
|
|
global first_load, vae_dict, vae_list, loaded_vae_file
|
|
|
|
|
# save_settings = False
|
|
|
|
|
|
|
|
|
|
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
|
|
|
|
|
|
|
|
|
if vae_file:
|
|
|
|
|
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
|
|
|
|
|
print(f"Loading VAE weights from: {vae_file}")
|
|
|
|
|
store_base_vae(model)
|
|
|
|
|
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
|
|
|
|
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
|
|
|
|
_load_vae_dict(model, vae_dict_1)
|
|
|
|
|
if cache_enabled and vae_file in checkpoints_loaded:
|
|
|
|
|
# use vae checkpoint cache
|
|
|
|
|
print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
|
|
|
|
|
store_base_vae(model)
|
|
|
|
|
_load_vae_dict(model, checkpoints_loaded[vae_file])
|
|
|
|
|
else:
|
|
|
|
|
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
|
|
|
|
|
print(f"Loading VAE weights from: {vae_file}")
|
|
|
|
|
store_base_vae(model)
|
|
|
|
|
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
|
|
|
|
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
|
|
|
|
_load_vae_dict(model, vae_dict_1)
|
|
|
|
|
|
|
|
|
|
if cache_enabled:
|
|
|
|
|
# cache newly loaded vae
|
|
|
|
|
checkpoints_loaded[vae_file] = vae_dict_1.copy()
|
|
|
|
|
|
|
|
|
|
# clean up cache if limit is reached
|
|
|
|
|
if cache_enabled:
|
|
|
|
|
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
|
|
|
|
|
checkpoints_loaded.popitem(last=False) # LRU
|
|
|
|
|
|
|
|
|
|
# If vae used is not in dict, update it
|
|
|
|
|
# It will be removed on refresh though
|
|
|
|
|
|