add safetensors to requirements

master
AUTOMATIC 3 years ago
parent f108782e30
commit 6074175faa

@ -5,6 +5,7 @@ import gc
from collections import namedtuple from collections import namedtuple
import torch import torch
import re import re
import safetensors.torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
@ -173,14 +174,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
# load from file # load from file
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
if checkpoint_file.endswith(".safetensors"): _, extension = os.path.splitext(checkpoint_file)
try: if extension.lower() == ".safetensors":
from safetensors.torch import load_file pl_sd = safetensors.torch.load_file(checkpoint_file, device=shared.weight_load_location)
except ImportError as e:
raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}")
pl_sd = load_file(checkpoint_file, device=shared.weight_load_location)
else: else:
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
if "global_step" in pl_sd: if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}") print(f"Global Step: {pl_sd['global_step']}")

@ -29,3 +29,4 @@ lark
inflection inflection
GitPython GitPython
torchsde torchsde
safetensors

@ -26,3 +26,4 @@ lark==1.1.2
inflection==0.5.1 inflection==0.5.1
GitPython==3.1.27 GitPython==3.1.27
torchsde==0.2.5 torchsde==0.2.5
safetensors==0.2.5

Loading…
Cancel
Save