|
|
|
@ -1,5 +1,8 @@
|
|
|
|
import glob
|
|
|
|
import glob
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
|
|
|
|
import sys
|
|
|
|
|
|
|
|
import traceback
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
from modules import devices
|
|
|
|
from modules import devices
|
|
|
|
|
|
|
|
|
|
|
|
@ -36,8 +39,12 @@ def load_hypernetworks(path):
|
|
|
|
res = {}
|
|
|
|
res = {}
|
|
|
|
|
|
|
|
|
|
|
|
for filename in glob.iglob(path + '**/*.pt', recursive=True):
|
|
|
|
for filename in glob.iglob(path + '**/*.pt', recursive=True):
|
|
|
|
hn = Hypernetwork(filename)
|
|
|
|
try:
|
|
|
|
res[hn.name] = hn
|
|
|
|
hn = Hypernetwork(filename)
|
|
|
|
|
|
|
|
res[hn.name] = hn
|
|
|
|
|
|
|
|
except Exception:
|
|
|
|
|
|
|
|
print(f"Error loading hypernetwork {filename}", file=sys.stderr)
|
|
|
|
|
|
|
|
print(traceback.format_exc(), file=sys.stderr)
|
|
|
|
|
|
|
|
|
|
|
|
return res
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
|
|
|