Add safetensors support to LDSR

master
wywywywy 3 years ago
parent 685f9631b5
commit 8bcdd50461

@ -1,3 +1,4 @@
import os
import gc import gc
import time import time
import warnings import warnings
@ -8,6 +9,7 @@ import torchvision
from PIL import Image from PIL import Image
from einops import rearrange, repeat from einops import rearrange, repeat
from omegaconf import OmegaConf from omegaconf import OmegaConf
import safetensors.torch
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap from ldm.util import instantiate_from_config, ismap
@ -28,8 +30,12 @@ class LDSR:
model: torch.nn.Module = cached_ldsr_model model: torch.nn.Module = cached_ldsr_model
else: else:
print(f"Loading model from {self.modelPath}") print(f"Loading model from {self.modelPath}")
_, extension = os.path.splitext(self.modelPath)
if extension.lower() == ".safetensors":
pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
else:
pl_sd = torch.load(self.modelPath, map_location="cpu") pl_sd = torch.load(self.modelPath, map_location="cpu")
sd = pl_sd["state_dict"] sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
config = OmegaConf.load(self.yamlPath) config = OmegaConf.load(self.yamlPath)
config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1" config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
model: torch.nn.Module = instantiate_from_config(config.model) model: torch.nn.Module = instantiate_from_config(config.model)

@ -25,6 +25,7 @@ class UpscalerLDSR(Upscaler):
yaml_path = os.path.join(self.model_path, "project.yaml") yaml_path = os.path.join(self.model_path, "project.yaml")
old_model_path = os.path.join(self.model_path, "model.pth") old_model_path = os.path.join(self.model_path, "model.pth")
new_model_path = os.path.join(self.model_path, "model.ckpt") new_model_path = os.path.join(self.model_path, "model.ckpt")
safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
if os.path.exists(yaml_path): if os.path.exists(yaml_path):
statinfo = os.stat(yaml_path) statinfo = os.stat(yaml_path)
if statinfo.st_size >= 10485760: if statinfo.st_size >= 10485760:
@ -33,6 +34,9 @@ class UpscalerLDSR(Upscaler):
if os.path.exists(old_model_path): if os.path.exists(old_model_path):
print("Renaming model from model.pth to model.ckpt") print("Renaming model from model.pth to model.ckpt")
os.rename(old_model_path, new_model_path) os.rename(old_model_path, new_model_path)
if os.path.exists(safetensors_model_path):
model = safetensors_model_path
else:
model = load_file_from_url(url=self.model_url, model_dir=self.model_path, model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
file_name="model.ckpt", progress=True) file_name="model.ckpt", progress=True)
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path, yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,

Loading…
Cancel
Save