Merge branch 'master' into saving
commit
a9d7eb722f
@ -1 +0,0 @@
|
|||||||
|
|
||||||
|
@ -0,0 +1,8 @@
|
|||||||
|
|
||||||
|
|
||||||
|
function start_training_textual_inversion(){
|
||||||
|
requestProgress('ti')
|
||||||
|
gradioApp().querySelector('#ti_error').innerHTML=''
|
||||||
|
|
||||||
|
return args_to_array(arguments)
|
||||||
|
}
|
||||||
@ -0,0 +1,78 @@
|
|||||||
|
import os.path
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import PIL.Image
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
|
import modules.upscaler
|
||||||
|
from modules import shared, modelloader
|
||||||
|
from modules.bsrgan_model_arch import RRDBNet
|
||||||
|
from modules.paths import models_path
|
||||||
|
|
||||||
|
|
||||||
|
class UpscalerBSRGAN(modules.upscaler.Upscaler):
|
||||||
|
def __init__(self, dirname):
|
||||||
|
self.name = "BSRGAN"
|
||||||
|
self.model_path = os.path.join(models_path, self.name)
|
||||||
|
self.model_name = "BSRGAN 4x"
|
||||||
|
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
|
||||||
|
self.user_path = dirname
|
||||||
|
super().__init__()
|
||||||
|
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
||||||
|
scalers = []
|
||||||
|
if len(model_paths) == 0:
|
||||||
|
scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
|
||||||
|
scalers.append(scaler_data)
|
||||||
|
for file in model_paths:
|
||||||
|
if "http" in file:
|
||||||
|
name = self.model_name
|
||||||
|
else:
|
||||||
|
name = modelloader.friendly_name(file)
|
||||||
|
try:
|
||||||
|
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
||||||
|
scalers.append(scaler_data)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
self.scalers = scalers
|
||||||
|
|
||||||
|
def do_upscale(self, img: PIL.Image, selected_file):
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
model = self.load_model(selected_file)
|
||||||
|
if model is None:
|
||||||
|
return img
|
||||||
|
model.to(shared.device)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
img = np.array(img)
|
||||||
|
img = img[:, :, ::-1]
|
||||||
|
img = np.moveaxis(img, 2, 0) / 255
|
||||||
|
img = torch.from_numpy(img).float()
|
||||||
|
img = img.unsqueeze(0).to(shared.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(img)
|
||||||
|
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
|
output = 255. * np.moveaxis(output, 0, 2)
|
||||||
|
output = output.astype(np.uint8)
|
||||||
|
output = output[:, :, ::-1]
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return PIL.Image.fromarray(output, 'RGB')
|
||||||
|
|
||||||
|
def load_model(self, path: str):
|
||||||
|
if "http" in path:
|
||||||
|
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
||||||
|
progress=True)
|
||||||
|
else:
|
||||||
|
filename = path
|
||||||
|
if not os.path.exists(filename) or filename is None:
|
||||||
|
print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
|
||||||
|
return None
|
||||||
|
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
|
||||||
|
model.load_state_dict(torch.load(filename), strict=True)
|
||||||
|
model.eval()
|
||||||
|
for k, v in model.named_parameters():
|
||||||
|
v.requires_grad = False
|
||||||
|
return model
|
||||||
|
|
||||||
@ -0,0 +1,102 @@
|
|||||||
|
import functools
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.nn.init as init
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_weights(net_l, scale=1):
|
||||||
|
if not isinstance(net_l, list):
|
||||||
|
net_l = [net_l]
|
||||||
|
for net in net_l:
|
||||||
|
for m in net.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||||
|
m.weight.data *= scale # for residual block
|
||||||
|
if m.bias is not None:
|
||||||
|
m.bias.data.zero_()
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||||
|
m.weight.data *= scale
|
||||||
|
if m.bias is not None:
|
||||||
|
m.bias.data.zero_()
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
init.constant_(m.weight, 1)
|
||||||
|
init.constant_(m.bias.data, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def make_layer(block, n_layers):
|
||||||
|
layers = []
|
||||||
|
for _ in range(n_layers):
|
||||||
|
layers.append(block())
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualDenseBlock_5C(nn.Module):
|
||||||
|
def __init__(self, nf=64, gc=32, bias=True):
|
||||||
|
super(ResidualDenseBlock_5C, self).__init__()
|
||||||
|
# gc: growth channel, i.e. intermediate channels
|
||||||
|
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||||
|
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||||
|
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||||
|
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||||
|
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
|
|
||||||
|
# initialization
|
||||||
|
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x1 = self.lrelu(self.conv1(x))
|
||||||
|
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||||
|
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||||
|
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||||
|
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||||
|
return x5 * 0.2 + x
|
||||||
|
|
||||||
|
|
||||||
|
class RRDB(nn.Module):
|
||||||
|
'''Residual in Residual Dense Block'''
|
||||||
|
|
||||||
|
def __init__(self, nf, gc=32):
|
||||||
|
super(RRDB, self).__init__()
|
||||||
|
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||||
|
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||||
|
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.RDB1(x)
|
||||||
|
out = self.RDB2(out)
|
||||||
|
out = self.RDB3(out)
|
||||||
|
return out * 0.2 + x
|
||||||
|
|
||||||
|
|
||||||
|
class RRDBNet(nn.Module):
|
||||||
|
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
|
||||||
|
super(RRDBNet, self).__init__()
|
||||||
|
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||||
|
self.sf = sf
|
||||||
|
|
||||||
|
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||||
|
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||||
|
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
#### upsampling
|
||||||
|
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
if self.sf==4:
|
||||||
|
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
fea = self.conv_first(x)
|
||||||
|
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||||
|
fea = fea + trunk
|
||||||
|
|
||||||
|
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||||
|
if self.sf==4:
|
||||||
|
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||||
|
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||||
|
|
||||||
|
return out
|
||||||
@ -1,67 +1,56 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
import modules.images
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
from modules.ldsr_model_arch import LDSR
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.paths import script_path
|
from modules.paths import models_path
|
||||||
|
|
||||||
LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"])
|
|
||||||
|
|
||||||
ldsr_models = []
|
class UpscalerLDSR(Upscaler):
|
||||||
have_ldsr = False
|
def __init__(self, user_path):
|
||||||
LDSR_obj = None
|
|
||||||
|
|
||||||
|
|
||||||
class UpscalerLDSR(modules.images.Upscaler):
|
|
||||||
def __init__(self, steps):
|
|
||||||
self.steps = steps
|
|
||||||
self.name = "LDSR"
|
self.name = "LDSR"
|
||||||
|
self.model_path = os.path.join(models_path, self.name)
|
||||||
def do_upscale(self, img):
|
self.user_path = user_path
|
||||||
return upscale_with_ldsr(img)
|
self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
||||||
|
self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
||||||
|
super().__init__()
|
||||||
def add_lsdr():
|
scaler_data = UpscalerData("LDSR", None, self)
|
||||||
modules.shared.sd_upscalers.append(UpscalerLDSR(100))
|
self.scalers = [scaler_data]
|
||||||
|
|
||||||
|
def load_model(self, path: str):
|
||||||
def setup_ldsr():
|
# Remove incorrect project.yaml file if too big
|
||||||
path = modules.paths.paths.get("LDSR", None)
|
yaml_path = os.path.join(self.model_path, "project.yaml")
|
||||||
if path is None:
|
old_model_path = os.path.join(self.model_path, "model.pth")
|
||||||
return
|
new_model_path = os.path.join(self.model_path, "model.ckpt")
|
||||||
global have_ldsr
|
if os.path.exists(yaml_path):
|
||||||
global LDSR_obj
|
statinfo = os.stat(yaml_path)
|
||||||
try:
|
if statinfo.st_size >= 10485760:
|
||||||
from LDSR import LDSR
|
print("Removing invalid LDSR YAML file.")
|
||||||
model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
os.remove(yaml_path)
|
||||||
yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
if os.path.exists(old_model_path):
|
||||||
repo_path = 'latent-diffusion/experiments/pretrained_models/'
|
print("Renaming model from model.pth to model.ckpt")
|
||||||
model_path = load_file_from_url(url=model_url, model_dir=os.path.join("repositories", repo_path),
|
os.rename(old_model_path, new_model_path)
|
||||||
progress=True, file_name="model.chkpt")
|
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
||||||
yaml_path = load_file_from_url(url=yaml_url, model_dir=os.path.join("repositories", repo_path),
|
file_name="model.ckpt", progress=True)
|
||||||
progress=True, file_name="project.yaml")
|
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
|
||||||
have_ldsr = True
|
file_name="project.yaml", progress=True)
|
||||||
LDSR_obj = LDSR(model_path, yaml_path)
|
|
||||||
|
try:
|
||||||
|
return LDSR(model, yaml)
|
||||||
except Exception:
|
|
||||||
print("Error importing LDSR:", file=sys.stderr)
|
except Exception:
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print("Error importing LDSR:", file=sys.stderr)
|
||||||
have_ldsr = False
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
return None
|
||||||
|
|
||||||
def upscale_with_ldsr(image):
|
def do_upscale(self, img, path):
|
||||||
setup_ldsr()
|
ldsr = self.load_model(path)
|
||||||
if not have_ldsr or LDSR_obj is None:
|
if ldsr is None:
|
||||||
return image
|
print("NO LDSR!")
|
||||||
|
return img
|
||||||
ddim_steps = shared.opts.ldsr_steps
|
ddim_steps = shared.opts.ldsr_steps
|
||||||
pre_scale = shared.opts.ldsr_pre_down
|
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
||||||
post_scale = shared.opts.ldsr_post_down
|
|
||||||
|
|
||||||
image = LDSR_obj.super_resolution(image, ddim_steps, pre_scale, post_scale)
|
|
||||||
return image
|
|
||||||
|
|||||||
@ -0,0 +1,222 @@
|
|||||||
|
import gc
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from PIL import Image
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
from ldm.util import instantiate_from_config, ismap
|
||||||
|
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
|
||||||
|
|
||||||
|
# Create LDSR Class
|
||||||
|
class LDSR:
|
||||||
|
def load_model_from_config(self, half_attention):
|
||||||
|
print(f"Loading model from {self.modelPath}")
|
||||||
|
pl_sd = torch.load(self.modelPath, map_location="cpu")
|
||||||
|
sd = pl_sd["state_dict"]
|
||||||
|
config = OmegaConf.load(self.yamlPath)
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
model.load_state_dict(sd, strict=False)
|
||||||
|
model.cuda()
|
||||||
|
if half_attention:
|
||||||
|
model = model.half()
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
return {"model": model}
|
||||||
|
|
||||||
|
def __init__(self, model_path, yaml_path):
|
||||||
|
self.modelPath = model_path
|
||||||
|
self.yamlPath = yaml_path
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def run(model, selected_path, custom_steps, eta):
|
||||||
|
example = get_cond(selected_path)
|
||||||
|
|
||||||
|
n_runs = 1
|
||||||
|
guider = None
|
||||||
|
ckwargs = None
|
||||||
|
ddim_use_x0_pred = False
|
||||||
|
temperature = 1.
|
||||||
|
eta = eta
|
||||||
|
custom_shape = None
|
||||||
|
|
||||||
|
height, width = example["image"].shape[1:3]
|
||||||
|
split_input = height >= 128 and width >= 128
|
||||||
|
|
||||||
|
if split_input:
|
||||||
|
ks = 128
|
||||||
|
stride = 64
|
||||||
|
vqf = 4 #
|
||||||
|
model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
|
||||||
|
"vqf": vqf,
|
||||||
|
"patch_distributed_vq": True,
|
||||||
|
"tie_braker": False,
|
||||||
|
"clip_max_weight": 0.5,
|
||||||
|
"clip_min_weight": 0.01,
|
||||||
|
"clip_max_tie_weight": 0.5,
|
||||||
|
"clip_min_tie_weight": 0.01}
|
||||||
|
else:
|
||||||
|
if hasattr(model, "split_input_params"):
|
||||||
|
delattr(model, "split_input_params")
|
||||||
|
|
||||||
|
x_t = None
|
||||||
|
logs = None
|
||||||
|
for n in range(n_runs):
|
||||||
|
if custom_shape is not None:
|
||||||
|
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
|
||||||
|
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
|
||||||
|
|
||||||
|
logs = make_convolutional_sample(example, model,
|
||||||
|
custom_steps=custom_steps,
|
||||||
|
eta=eta, quantize_x0=False,
|
||||||
|
custom_shape=custom_shape,
|
||||||
|
temperature=temperature, noise_dropout=0.,
|
||||||
|
corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
|
||||||
|
ddim_use_x0_pred=ddim_use_x0_pred
|
||||||
|
)
|
||||||
|
return logs
|
||||||
|
|
||||||
|
def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
|
||||||
|
model = self.load_model_from_config(half_attention)
|
||||||
|
|
||||||
|
# Run settings
|
||||||
|
diffusion_steps = int(steps)
|
||||||
|
eta = 1.0
|
||||||
|
|
||||||
|
down_sample_method = 'Lanczos'
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
im_og = image
|
||||||
|
width_og, height_og = im_og.size
|
||||||
|
# If we can adjust the max upscale size, then the 4 below should be our variable
|
||||||
|
down_sample_rate = target_scale / 4
|
||||||
|
wd = width_og * down_sample_rate
|
||||||
|
hd = height_og * down_sample_rate
|
||||||
|
width_downsampled_pre = int(wd)
|
||||||
|
height_downsampled_pre = int(hd)
|
||||||
|
|
||||||
|
if down_sample_rate != 1:
|
||||||
|
print(
|
||||||
|
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
|
||||||
|
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
||||||
|
else:
|
||||||
|
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
|
||||||
|
logs = self.run(model["model"], im_og, diffusion_steps, eta)
|
||||||
|
|
||||||
|
sample = logs["sample"]
|
||||||
|
sample = sample.detach().cpu()
|
||||||
|
sample = torch.clamp(sample, -1., 1.)
|
||||||
|
sample = (sample + 1.) / 2. * 255
|
||||||
|
sample = sample.numpy().astype(np.uint8)
|
||||||
|
sample = np.transpose(sample, (0, 2, 3, 1))
|
||||||
|
a = Image.fromarray(sample[0])
|
||||||
|
|
||||||
|
del model
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return a
|
||||||
|
|
||||||
|
|
||||||
|
def get_cond(selected_path):
|
||||||
|
example = dict()
|
||||||
|
up_f = 4
|
||||||
|
c = selected_path.convert('RGB')
|
||||||
|
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
|
||||||
|
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
|
||||||
|
antialias=True)
|
||||||
|
c_up = rearrange(c_up, '1 c h w -> 1 h w c')
|
||||||
|
c = rearrange(c, '1 c h w -> 1 h w c')
|
||||||
|
c = 2. * c - 1.
|
||||||
|
|
||||||
|
c = c.to(torch.device("cuda"))
|
||||||
|
example["LR_image"] = c
|
||||||
|
example["image"] = c_up
|
||||||
|
|
||||||
|
return example
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
|
||||||
|
mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
|
||||||
|
corrector_kwargs=None, x_t=None
|
||||||
|
):
|
||||||
|
ddim = DDIMSampler(model)
|
||||||
|
bs = shape[0]
|
||||||
|
shape = shape[1:]
|
||||||
|
print(f"Sampling with eta = {eta}; steps: {steps}")
|
||||||
|
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
|
||||||
|
normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
|
||||||
|
mask=mask, x0=x0, temperature=temperature, verbose=False,
|
||||||
|
score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs, x_t=x_t)
|
||||||
|
|
||||||
|
return samples, intermediates
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
|
||||||
|
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
|
||||||
|
log = dict()
|
||||||
|
|
||||||
|
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
|
||||||
|
return_first_stage_outputs=True,
|
||||||
|
force_c_encode=not (hasattr(model, 'split_input_params')
|
||||||
|
and model.cond_stage_key == 'coordinates_bbox'),
|
||||||
|
return_original_cond=True)
|
||||||
|
|
||||||
|
if custom_shape is not None:
|
||||||
|
z = torch.randn(custom_shape)
|
||||||
|
print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
|
||||||
|
|
||||||
|
z0 = None
|
||||||
|
|
||||||
|
log["input"] = x
|
||||||
|
log["reconstruction"] = xrec
|
||||||
|
|
||||||
|
if ismap(xc):
|
||||||
|
log["original_conditioning"] = model.to_rgb(xc)
|
||||||
|
if hasattr(model, 'cond_stage_key'):
|
||||||
|
log[model.cond_stage_key] = model.to_rgb(xc)
|
||||||
|
|
||||||
|
else:
|
||||||
|
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
|
||||||
|
if model.cond_stage_model:
|
||||||
|
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
|
||||||
|
if model.cond_stage_key == 'class_label':
|
||||||
|
log[model.cond_stage_key] = xc[model.cond_stage_key]
|
||||||
|
|
||||||
|
with model.ema_scope("Plotting"):
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
|
||||||
|
eta=eta,
|
||||||
|
quantize_x0=quantize_x0, mask=None, x0=z0,
|
||||||
|
temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
|
||||||
|
x_t=x_T)
|
||||||
|
t1 = time.time()
|
||||||
|
|
||||||
|
if ddim_use_x0_pred:
|
||||||
|
sample = intermediates['pred_x0'][-1]
|
||||||
|
|
||||||
|
x_sample = model.decode_first_stage(sample)
|
||||||
|
|
||||||
|
try:
|
||||||
|
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
||||||
|
log["sample_noquant"] = x_sample_noquant
|
||||||
|
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
log["sample"] = x_sample
|
||||||
|
log["time"] = t1 - t0
|
||||||
|
|
||||||
|
return log
|
||||||
@ -0,0 +1,140 @@
|
|||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import importlib
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.upscaler import Upscaler
|
||||||
|
from modules.paths import script_path, models_path
|
||||||
|
|
||||||
|
|
||||||
|
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
|
||||||
|
"""
|
||||||
|
A one-and done loader to try finding the desired models in specified directories.
|
||||||
|
|
||||||
|
@param download_name: Specify to download from model_url immediately.
|
||||||
|
@param model_url: If no other models are found, this will be downloaded on upscale.
|
||||||
|
@param model_path: The location to store/find models in.
|
||||||
|
@param command_path: A command-line argument to search for models in first.
|
||||||
|
@param ext_filter: An optional list of filename extensions to filter by
|
||||||
|
@return: A list of paths containing the desired model(s)
|
||||||
|
"""
|
||||||
|
output = []
|
||||||
|
|
||||||
|
if ext_filter is None:
|
||||||
|
ext_filter = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
places = []
|
||||||
|
|
||||||
|
if command_path is not None and command_path != model_path:
|
||||||
|
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
|
||||||
|
if os.path.exists(pretrained_path):
|
||||||
|
print(f"Appending path: {pretrained_path}")
|
||||||
|
places.append(pretrained_path)
|
||||||
|
elif os.path.exists(command_path):
|
||||||
|
places.append(command_path)
|
||||||
|
|
||||||
|
places.append(model_path)
|
||||||
|
|
||||||
|
for place in places:
|
||||||
|
if os.path.exists(place):
|
||||||
|
for file in glob.iglob(place + '**/**', recursive=True):
|
||||||
|
full_path = file
|
||||||
|
if os.path.isdir(full_path):
|
||||||
|
continue
|
||||||
|
if len(ext_filter) != 0:
|
||||||
|
model_name, extension = os.path.splitext(file)
|
||||||
|
if extension not in ext_filter:
|
||||||
|
continue
|
||||||
|
if file not in output:
|
||||||
|
output.append(full_path)
|
||||||
|
|
||||||
|
if model_url is not None and len(output) == 0:
|
||||||
|
if download_name is not None:
|
||||||
|
dl = load_file_from_url(model_url, model_path, True, download_name)
|
||||||
|
output.append(dl)
|
||||||
|
else:
|
||||||
|
output.append(model_url)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def friendly_name(file: str):
|
||||||
|
if "http" in file:
|
||||||
|
file = urlparse(file).path
|
||||||
|
|
||||||
|
file = os.path.basename(file)
|
||||||
|
model_name, extension = os.path.splitext(file)
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_models():
|
||||||
|
# This code could probably be more efficient if we used a tuple list or something to store the src/destinations
|
||||||
|
# and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
|
||||||
|
# somehow auto-register and just do these things...
|
||||||
|
root_path = script_path
|
||||||
|
src_path = models_path
|
||||||
|
dest_path = os.path.join(models_path, "Stable-diffusion")
|
||||||
|
move_files(src_path, dest_path, ".ckpt")
|
||||||
|
src_path = os.path.join(root_path, "ESRGAN")
|
||||||
|
dest_path = os.path.join(models_path, "ESRGAN")
|
||||||
|
move_files(src_path, dest_path)
|
||||||
|
src_path = os.path.join(root_path, "gfpgan")
|
||||||
|
dest_path = os.path.join(models_path, "GFPGAN")
|
||||||
|
move_files(src_path, dest_path)
|
||||||
|
src_path = os.path.join(root_path, "SwinIR")
|
||||||
|
dest_path = os.path.join(models_path, "SwinIR")
|
||||||
|
move_files(src_path, dest_path)
|
||||||
|
src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
|
||||||
|
dest_path = os.path.join(models_path, "LDSR")
|
||||||
|
move_files(src_path, dest_path)
|
||||||
|
|
||||||
|
|
||||||
|
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
||||||
|
try:
|
||||||
|
if not os.path.exists(dest_path):
|
||||||
|
os.makedirs(dest_path)
|
||||||
|
if os.path.exists(src_path):
|
||||||
|
for file in os.listdir(src_path):
|
||||||
|
fullpath = os.path.join(src_path, file)
|
||||||
|
if os.path.isfile(fullpath):
|
||||||
|
if ext_filter is not None:
|
||||||
|
if ext_filter not in file:
|
||||||
|
continue
|
||||||
|
print(f"Moving {file} from {src_path} to {dest_path}.")
|
||||||
|
try:
|
||||||
|
shutil.move(fullpath, dest_path)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if len(os.listdir(src_path)) == 0:
|
||||||
|
print(f"Removing empty folder: {src_path}")
|
||||||
|
shutil.rmtree(src_path, True)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def load_upscalers():
|
||||||
|
datas = []
|
||||||
|
for cls in Upscaler.__subclasses__():
|
||||||
|
name = cls.__name__
|
||||||
|
module_name = cls.__module__
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
class_ = getattr(module, name)
|
||||||
|
cmd_name = f"{name.lower().replace('upscaler', '')}-models-path"
|
||||||
|
opt_string = None
|
||||||
|
try:
|
||||||
|
opt_string = shared.opts.__getattr__(cmd_name)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
scaler = class_(opt_string)
|
||||||
|
for child in scaler.scalers:
|
||||||
|
datas.append(child)
|
||||||
|
|
||||||
|
shared.sd_upscalers = datas
|
||||||
@ -0,0 +1,164 @@
|
|||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch import einsum
|
||||||
|
|
||||||
|
from ldm.util import default
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
||||||
|
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||||
|
h = self.heads
|
||||||
|
|
||||||
|
q = self.to_q(x)
|
||||||
|
context = default(context, x)
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
|
del context, x
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||||
|
|
||||||
|
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||||
|
for i in range(0, q.shape[0], 2):
|
||||||
|
end = i + 2
|
||||||
|
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||||
|
s1 *= self.scale
|
||||||
|
|
||||||
|
s2 = s1.softmax(dim=-1)
|
||||||
|
del s1
|
||||||
|
|
||||||
|
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||||
|
del s2
|
||||||
|
|
||||||
|
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
del r1
|
||||||
|
|
||||||
|
return self.to_out(r2)
|
||||||
|
|
||||||
|
|
||||||
|
# taken from https://github.com/Doggettx/stable-diffusion
|
||||||
|
def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||||
|
h = self.heads
|
||||||
|
|
||||||
|
q_in = self.to_q(x)
|
||||||
|
context = default(context, x)
|
||||||
|
k_in = self.to_k(context) * self.scale
|
||||||
|
v_in = self.to_v(context)
|
||||||
|
del context, x
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||||
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
|
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
|
|
||||||
|
stats = torch.cuda.memory_stats(q.device)
|
||||||
|
mem_active = stats['active_bytes.all.current']
|
||||||
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
|
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||||
|
mem_free_torch = mem_reserved - mem_active
|
||||||
|
mem_free_total = mem_free_cuda + mem_free_torch
|
||||||
|
|
||||||
|
gb = 1024 ** 3
|
||||||
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||||
|
modifier = 3 if q.element_size() == 2 else 2.5
|
||||||
|
mem_required = tensor_size * modifier
|
||||||
|
steps = 1
|
||||||
|
|
||||||
|
if mem_required > mem_free_total:
|
||||||
|
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||||
|
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||||
|
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||||
|
|
||||||
|
if steps > 64:
|
||||||
|
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||||
|
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||||
|
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||||
|
|
||||||
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||||
|
for i in range(0, q.shape[1], slice_size):
|
||||||
|
end = i + slice_size
|
||||||
|
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||||
|
|
||||||
|
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||||
|
del s1
|
||||||
|
|
||||||
|
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||||
|
del s2
|
||||||
|
|
||||||
|
del q, k, v
|
||||||
|
|
||||||
|
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
del r1
|
||||||
|
|
||||||
|
return self.to_out(r2)
|
||||||
|
|
||||||
|
def nonlinearity_hijack(x):
|
||||||
|
# swish
|
||||||
|
t = torch.sigmoid(x)
|
||||||
|
x *= t
|
||||||
|
del t
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def cross_attention_attnblock_forward(self, x):
|
||||||
|
h_ = x
|
||||||
|
h_ = self.norm(h_)
|
||||||
|
q1 = self.q(h_)
|
||||||
|
k1 = self.k(h_)
|
||||||
|
v = self.v(h_)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
b, c, h, w = q1.shape
|
||||||
|
|
||||||
|
q2 = q1.reshape(b, c, h*w)
|
||||||
|
del q1
|
||||||
|
|
||||||
|
q = q2.permute(0, 2, 1) # b,hw,c
|
||||||
|
del q2
|
||||||
|
|
||||||
|
k = k1.reshape(b, c, h*w) # b,c,hw
|
||||||
|
del k1
|
||||||
|
|
||||||
|
h_ = torch.zeros_like(k, device=q.device)
|
||||||
|
|
||||||
|
stats = torch.cuda.memory_stats(q.device)
|
||||||
|
mem_active = stats['active_bytes.all.current']
|
||||||
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
|
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||||
|
mem_free_torch = mem_reserved - mem_active
|
||||||
|
mem_free_total = mem_free_cuda + mem_free_torch
|
||||||
|
|
||||||
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||||
|
mem_required = tensor_size * 2.5
|
||||||
|
steps = 1
|
||||||
|
|
||||||
|
if mem_required > mem_free_total:
|
||||||
|
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||||
|
|
||||||
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||||
|
for i in range(0, q.shape[1], slice_size):
|
||||||
|
end = i + slice_size
|
||||||
|
|
||||||
|
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||||
|
w2 = w1 * (int(c)**(-0.5))
|
||||||
|
del w1
|
||||||
|
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
|
||||||
|
del w2
|
||||||
|
|
||||||
|
# attend to values
|
||||||
|
v1 = v.reshape(b, c, h*w)
|
||||||
|
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||||
|
del w3
|
||||||
|
|
||||||
|
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||||
|
del v1, w4
|
||||||
|
|
||||||
|
h2 = h_.reshape(b, c, h, w)
|
||||||
|
del h_
|
||||||
|
|
||||||
|
h3 = self.proj_out(h2)
|
||||||
|
del h2
|
||||||
|
|
||||||
|
h3 += x
|
||||||
|
|
||||||
|
return h3
|
||||||
@ -1,123 +0,0 @@
|
|||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import cv2
|
|
||||||
import os
|
|
||||||
import contextlib
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
import modules.images
|
|
||||||
from modules.shared import cmd_opts, opts, device
|
|
||||||
from modules.swinir_arch import SwinIR as net
|
|
||||||
|
|
||||||
precision_scope = (
|
|
||||||
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(filename, scale=4):
|
|
||||||
model = net(
|
|
||||||
upscale=scale,
|
|
||||||
in_chans=3,
|
|
||||||
img_size=64,
|
|
||||||
window_size=8,
|
|
||||||
img_range=1.0,
|
|
||||||
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
|
||||||
embed_dim=240,
|
|
||||||
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
|
||||||
mlp_ratio=2,
|
|
||||||
upsampler="nearest+conv",
|
|
||||||
resi_connection="3conv",
|
|
||||||
)
|
|
||||||
|
|
||||||
pretrained_model = torch.load(filename)
|
|
||||||
model.load_state_dict(pretrained_model["params_ema"], strict=True)
|
|
||||||
if not cmd_opts.no_half:
|
|
||||||
model = model.half()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def load_models(dirname):
|
|
||||||
for file in os.listdir(dirname):
|
|
||||||
path = os.path.join(dirname, file)
|
|
||||||
model_name, extension = os.path.splitext(file)
|
|
||||||
|
|
||||||
if extension != ".pt" and extension != ".pth":
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name))
|
|
||||||
except Exception:
|
|
||||||
print(f"Error loading SwinIR model: {path}", file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
def upscale(
|
|
||||||
img,
|
|
||||||
model,
|
|
||||||
tile=opts.SWIN_tile,
|
|
||||||
tile_overlap=opts.SWIN_tile_overlap,
|
|
||||||
window_size=8,
|
|
||||||
scale=4,
|
|
||||||
):
|
|
||||||
img = np.array(img)
|
|
||||||
img = img[:, :, ::-1]
|
|
||||||
img = np.moveaxis(img, 2, 0) / 255
|
|
||||||
img = torch.from_numpy(img).float()
|
|
||||||
img = img.unsqueeze(0).to(device)
|
|
||||||
with torch.no_grad(), precision_scope("cuda"):
|
|
||||||
_, _, h_old, w_old = img.size()
|
|
||||||
h_pad = (h_old // window_size + 1) * window_size - h_old
|
|
||||||
w_pad = (w_old // window_size + 1) * window_size - w_old
|
|
||||||
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
|
|
||||||
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
|
|
||||||
output = inference(img, model, tile, tile_overlap, window_size, scale)
|
|
||||||
output = output[..., : h_old * scale, : w_old * scale]
|
|
||||||
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
||||||
if output.ndim == 3:
|
|
||||||
output = np.transpose(
|
|
||||||
output[[2, 1, 0], :, :], (1, 2, 0)
|
|
||||||
) # CHW-RGB to HCW-BGR
|
|
||||||
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
|
|
||||||
return Image.fromarray(output, "RGB")
|
|
||||||
|
|
||||||
|
|
||||||
def inference(img, model, tile, tile_overlap, window_size, scale):
|
|
||||||
# test the image tile by tile
|
|
||||||
b, c, h, w = img.size()
|
|
||||||
tile = min(tile, h, w)
|
|
||||||
assert tile % window_size == 0, "tile size should be a multiple of window_size"
|
|
||||||
sf = scale
|
|
||||||
|
|
||||||
stride = tile - tile_overlap
|
|
||||||
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
|
|
||||||
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
|
|
||||||
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
|
|
||||||
W = torch.zeros_like(E, dtype=torch.half, device=device)
|
|
||||||
|
|
||||||
for h_idx in h_idx_list:
|
|
||||||
for w_idx in w_idx_list:
|
|
||||||
in_patch = img[..., h_idx : h_idx + tile, w_idx : w_idx + tile]
|
|
||||||
out_patch = model(in_patch)
|
|
||||||
out_patch_mask = torch.ones_like(out_patch)
|
|
||||||
|
|
||||||
E[
|
|
||||||
..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
|
|
||||||
].add_(out_patch)
|
|
||||||
W[
|
|
||||||
..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
|
|
||||||
].add_(out_patch_mask)
|
|
||||||
output = E.div_(W)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class UpscalerSwin(modules.images.Upscaler):
|
|
||||||
def __init__(self, filename, title):
|
|
||||||
self.name = title
|
|
||||||
self.model = load_model(filename)
|
|
||||||
|
|
||||||
def do_upscale(self, img):
|
|
||||||
model = self.model.to(device)
|
|
||||||
img = upscale(img, model)
|
|
||||||
return img
|
|
||||||
@ -0,0 +1,142 @@
|
|||||||
|
import contextlib
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from modules import modelloader
|
||||||
|
from modules.paths import models_path
|
||||||
|
from modules.shared import cmd_opts, opts, device
|
||||||
|
from modules.swinir_model_arch import SwinIR as net
|
||||||
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
|
||||||
|
precision_scope = (
|
||||||
|
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UpscalerSwinIR(Upscaler):
|
||||||
|
def __init__(self, dirname):
|
||||||
|
self.name = "SwinIR"
|
||||||
|
self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
|
||||||
|
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
|
||||||
|
"-L_x4_GAN.pth "
|
||||||
|
self.model_name = "SwinIR 4x"
|
||||||
|
self.model_path = os.path.join(models_path, self.name)
|
||||||
|
self.user_path = dirname
|
||||||
|
super().__init__()
|
||||||
|
scalers = []
|
||||||
|
model_files = self.find_models(ext_filter=[".pt", ".pth"])
|
||||||
|
for model in model_files:
|
||||||
|
if "http" in model:
|
||||||
|
name = self.model_name
|
||||||
|
else:
|
||||||
|
name = modelloader.friendly_name(model)
|
||||||
|
model_data = UpscalerData(name, model, self)
|
||||||
|
scalers.append(model_data)
|
||||||
|
self.scalers = scalers
|
||||||
|
|
||||||
|
def do_upscale(self, img, model_file):
|
||||||
|
model = self.load_model(model_file)
|
||||||
|
if model is None:
|
||||||
|
return img
|
||||||
|
model = model.to(device)
|
||||||
|
img = upscale(img, model)
|
||||||
|
try:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return img
|
||||||
|
|
||||||
|
def load_model(self, path, scale=4):
|
||||||
|
if "http" in path:
|
||||||
|
dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
|
||||||
|
filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
|
||||||
|
else:
|
||||||
|
filename = path
|
||||||
|
if filename is None or not os.path.exists(filename):
|
||||||
|
return None
|
||||||
|
model = net(
|
||||||
|
upscale=scale,
|
||||||
|
in_chans=3,
|
||||||
|
img_size=64,
|
||||||
|
window_size=8,
|
||||||
|
img_range=1.0,
|
||||||
|
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
||||||
|
embed_dim=240,
|
||||||
|
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
||||||
|
mlp_ratio=2,
|
||||||
|
upsampler="nearest+conv",
|
||||||
|
resi_connection="3conv",
|
||||||
|
)
|
||||||
|
|
||||||
|
pretrained_model = torch.load(filename)
|
||||||
|
model.load_state_dict(pretrained_model["params_ema"], strict=True)
|
||||||
|
if not cmd_opts.no_half:
|
||||||
|
model = model.half()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def upscale(
|
||||||
|
img,
|
||||||
|
model,
|
||||||
|
tile=opts.SWIN_tile,
|
||||||
|
tile_overlap=opts.SWIN_tile_overlap,
|
||||||
|
window_size=8,
|
||||||
|
scale=4,
|
||||||
|
):
|
||||||
|
img = np.array(img)
|
||||||
|
img = img[:, :, ::-1]
|
||||||
|
img = np.moveaxis(img, 2, 0) / 255
|
||||||
|
img = torch.from_numpy(img).float()
|
||||||
|
img = img.unsqueeze(0).to(device)
|
||||||
|
with torch.no_grad(), precision_scope("cuda"):
|
||||||
|
_, _, h_old, w_old = img.size()
|
||||||
|
h_pad = (h_old // window_size + 1) * window_size - h_old
|
||||||
|
w_pad = (w_old // window_size + 1) * window_size - w_old
|
||||||
|
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
|
||||||
|
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
|
||||||
|
output = inference(img, model, tile, tile_overlap, window_size, scale)
|
||||||
|
output = output[..., : h_old * scale, : w_old * scale]
|
||||||
|
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
|
if output.ndim == 3:
|
||||||
|
output = np.transpose(
|
||||||
|
output[[2, 1, 0], :, :], (1, 2, 0)
|
||||||
|
) # CHW-RGB to HCW-BGR
|
||||||
|
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
|
||||||
|
return Image.fromarray(output, "RGB")
|
||||||
|
|
||||||
|
|
||||||
|
def inference(img, model, tile, tile_overlap, window_size, scale):
|
||||||
|
# test the image tile by tile
|
||||||
|
b, c, h, w = img.size()
|
||||||
|
tile = min(tile, h, w)
|
||||||
|
assert tile % window_size == 0, "tile size should be a multiple of window_size"
|
||||||
|
sf = scale
|
||||||
|
|
||||||
|
stride = tile - tile_overlap
|
||||||
|
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
|
||||||
|
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
|
||||||
|
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
|
||||||
|
W = torch.zeros_like(E, dtype=torch.half, device=device)
|
||||||
|
|
||||||
|
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
|
||||||
|
for h_idx in h_idx_list:
|
||||||
|
for w_idx in w_idx_list:
|
||||||
|
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
||||||
|
out_patch = model(in_patch)
|
||||||
|
out_patch_mask = torch.ones_like(out_patch)
|
||||||
|
|
||||||
|
E[
|
||||||
|
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||||
|
].add_(out_patch)
|
||||||
|
W[
|
||||||
|
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||||
|
].add_(out_patch_mask)
|
||||||
|
pbar.update(1)
|
||||||
|
output = E.div_(W)
|
||||||
|
|
||||||
|
return output
|
||||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,76 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
import random
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class PersonalizedBase(Dataset):
|
||||||
|
def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
|
||||||
|
|
||||||
|
self.placeholder_token = placeholder_token
|
||||||
|
|
||||||
|
self.size = size
|
||||||
|
self.width = width
|
||||||
|
self.height = height
|
||||||
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
|
|
||||||
|
self.dataset = []
|
||||||
|
|
||||||
|
with open(template_file, "r") as file:
|
||||||
|
lines = [x.strip() for x in file.readlines()]
|
||||||
|
|
||||||
|
self.lines = lines
|
||||||
|
|
||||||
|
assert data_root, 'dataset directory not specified'
|
||||||
|
|
||||||
|
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||||
|
print("Preparing dataset...")
|
||||||
|
for path in tqdm.tqdm(self.image_paths):
|
||||||
|
image = Image.open(path)
|
||||||
|
image = image.convert('RGB')
|
||||||
|
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
|
||||||
|
|
||||||
|
filename = os.path.basename(path)
|
||||||
|
filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
|
||||||
|
filename_tokens = [token for token in filename_tokens if token.isalpha()]
|
||||||
|
|
||||||
|
npimage = np.array(image).astype(np.uint8)
|
||||||
|
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||||
|
|
||||||
|
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
|
||||||
|
torchdata = torch.moveaxis(torchdata, 2, 0)
|
||||||
|
|
||||||
|
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
|
||||||
|
|
||||||
|
self.dataset.append((init_latent, filename_tokens))
|
||||||
|
|
||||||
|
self.length = len(self.dataset) * repeats
|
||||||
|
|
||||||
|
self.initial_indexes = np.arange(self.length) % len(self.dataset)
|
||||||
|
self.indexes = None
|
||||||
|
self.shuffle()
|
||||||
|
|
||||||
|
def shuffle(self):
|
||||||
|
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
if i % len(self.dataset) == 0:
|
||||||
|
self.shuffle()
|
||||||
|
|
||||||
|
index = self.indexes[i % len(self.indexes)]
|
||||||
|
x, filename_tokens = self.dataset[index]
|
||||||
|
|
||||||
|
text = random.choice(self.lines)
|
||||||
|
text = text.replace("[name]", self.placeholder_token)
|
||||||
|
text = text.replace("[filewords]", ' '.join(filename_tokens))
|
||||||
|
|
||||||
|
return x, text
|
||||||
@ -0,0 +1,258 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
import html
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from modules import shared, devices, sd_hijack, processing
|
||||||
|
import modules.textual_inversion.dataset
|
||||||
|
|
||||||
|
|
||||||
|
class Embedding:
|
||||||
|
def __init__(self, vec, name, step=None):
|
||||||
|
self.vec = vec
|
||||||
|
self.name = name
|
||||||
|
self.step = step
|
||||||
|
self.cached_checksum = None
|
||||||
|
|
||||||
|
def save(self, filename):
|
||||||
|
embedding_data = {
|
||||||
|
"string_to_token": {"*": 265},
|
||||||
|
"string_to_param": {"*": self.vec},
|
||||||
|
"name": self.name,
|
||||||
|
"step": self.step,
|
||||||
|
}
|
||||||
|
|
||||||
|
torch.save(embedding_data, filename)
|
||||||
|
|
||||||
|
def checksum(self):
|
||||||
|
if self.cached_checksum is not None:
|
||||||
|
return self.cached_checksum
|
||||||
|
|
||||||
|
def const_hash(a):
|
||||||
|
r = 0
|
||||||
|
for v in a:
|
||||||
|
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
|
||||||
|
return r
|
||||||
|
|
||||||
|
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
|
||||||
|
return self.cached_checksum
|
||||||
|
|
||||||
|
class EmbeddingDatabase:
|
||||||
|
def __init__(self, embeddings_dir):
|
||||||
|
self.ids_lookup = {}
|
||||||
|
self.word_embeddings = {}
|
||||||
|
self.dir_mtime = None
|
||||||
|
self.embeddings_dir = embeddings_dir
|
||||||
|
|
||||||
|
def register_embedding(self, embedding, model):
|
||||||
|
|
||||||
|
self.word_embeddings[embedding.name] = embedding
|
||||||
|
|
||||||
|
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
|
||||||
|
|
||||||
|
first_id = ids[0]
|
||||||
|
if first_id not in self.ids_lookup:
|
||||||
|
self.ids_lookup[first_id] = []
|
||||||
|
self.ids_lookup[first_id].append((ids, embedding))
|
||||||
|
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
def load_textual_inversion_embeddings(self):
|
||||||
|
mt = os.path.getmtime(self.embeddings_dir)
|
||||||
|
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.dir_mtime = mt
|
||||||
|
self.ids_lookup.clear()
|
||||||
|
self.word_embeddings.clear()
|
||||||
|
|
||||||
|
def process_file(path, filename):
|
||||||
|
name = os.path.splitext(filename)[0]
|
||||||
|
|
||||||
|
data = torch.load(path, map_location="cpu")
|
||||||
|
|
||||||
|
# textual inversion embeddings
|
||||||
|
if 'string_to_param' in data:
|
||||||
|
param_dict = data['string_to_param']
|
||||||
|
if hasattr(param_dict, '_parameters'):
|
||||||
|
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||||
|
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||||
|
emb = next(iter(param_dict.items()))[1]
|
||||||
|
# diffuser concepts
|
||||||
|
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||||
|
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||||
|
|
||||||
|
emb = next(iter(data.values()))
|
||||||
|
if len(emb.shape) == 1:
|
||||||
|
emb = emb.unsqueeze(0)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||||
|
|
||||||
|
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||||
|
embedding = Embedding(vec, name)
|
||||||
|
embedding.step = data.get('step', None)
|
||||||
|
self.register_embedding(embedding, shared.sd_model)
|
||||||
|
|
||||||
|
for fn in os.listdir(self.embeddings_dir):
|
||||||
|
try:
|
||||||
|
fullfn = os.path.join(self.embeddings_dir, fn)
|
||||||
|
|
||||||
|
if os.stat(fullfn).st_size == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
process_file(fullfn, fn)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error loading emedding {fn}:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
||||||
|
|
||||||
|
def find_embedding_at_position(self, tokens, offset):
|
||||||
|
token = tokens[offset]
|
||||||
|
possible_matches = self.ids_lookup.get(token, None)
|
||||||
|
|
||||||
|
if possible_matches is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for ids, embedding in possible_matches:
|
||||||
|
if tokens[offset:offset + len(ids)] == ids:
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding(name, num_vectors_per_token):
|
||||||
|
init_text = '*'
|
||||||
|
|
||||||
|
cond_model = shared.sd_model.cond_stage_model
|
||||||
|
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
||||||
|
|
||||||
|
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||||
|
embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
|
||||||
|
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
||||||
|
|
||||||
|
for i in range(num_vectors_per_token):
|
||||||
|
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
||||||
|
|
||||||
|
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
||||||
|
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||||
|
|
||||||
|
embedding = Embedding(vec, name)
|
||||||
|
embedding.step = 0
|
||||||
|
embedding.save(fn)
|
||||||
|
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
|
||||||
|
assert embedding_name, 'embedding not selected'
|
||||||
|
|
||||||
|
shared.state.textinfo = "Initializing textual inversion training..."
|
||||||
|
shared.state.job_count = steps
|
||||||
|
|
||||||
|
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||||
|
|
||||||
|
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
|
||||||
|
|
||||||
|
if save_embedding_every > 0:
|
||||||
|
embedding_dir = os.path.join(log_directory, "embeddings")
|
||||||
|
os.makedirs(embedding_dir, exist_ok=True)
|
||||||
|
else:
|
||||||
|
embedding_dir = None
|
||||||
|
|
||||||
|
if create_image_every > 0:
|
||||||
|
images_dir = os.path.join(log_directory, "images")
|
||||||
|
os.makedirs(images_dir, exist_ok=True)
|
||||||
|
else:
|
||||||
|
images_dir = None
|
||||||
|
|
||||||
|
cond_model = shared.sd_model.cond_stage_model
|
||||||
|
|
||||||
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
|
with torch.autocast("cuda"):
|
||||||
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
||||||
|
|
||||||
|
hijack = sd_hijack.model_hijack
|
||||||
|
|
||||||
|
embedding = hijack.embedding_db.word_embeddings[embedding_name]
|
||||||
|
embedding.vec.requires_grad = True
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
|
||||||
|
|
||||||
|
losses = torch.zeros((32,))
|
||||||
|
|
||||||
|
last_saved_file = "<none>"
|
||||||
|
last_saved_image = "<none>"
|
||||||
|
|
||||||
|
ititial_step = embedding.step or 0
|
||||||
|
if ititial_step > steps:
|
||||||
|
return embedding, filename
|
||||||
|
|
||||||
|
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||||
|
for i, (x, text) in pbar:
|
||||||
|
embedding.step = i + ititial_step
|
||||||
|
|
||||||
|
if embedding.step > steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
|
|
||||||
|
with torch.autocast("cuda"):
|
||||||
|
c = cond_model([text])
|
||||||
|
loss = shared.sd_model(x.unsqueeze(0), c)[0]
|
||||||
|
|
||||||
|
losses[embedding.step % losses.shape[0]] = loss.item()
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
pbar.set_description(f"loss: {losses.mean():.7f}")
|
||||||
|
|
||||||
|
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
||||||
|
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
||||||
|
embedding.save(last_saved_file)
|
||||||
|
|
||||||
|
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
||||||
|
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
||||||
|
|
||||||
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
|
sd_model=shared.sd_model,
|
||||||
|
prompt=text,
|
||||||
|
steps=20,
|
||||||
|
do_not_save_grid=True,
|
||||||
|
do_not_save_samples=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
processed = processing.process_images(p)
|
||||||
|
image = processed.images[0]
|
||||||
|
|
||||||
|
shared.state.current_image = image
|
||||||
|
image.save(last_saved_image)
|
||||||
|
|
||||||
|
last_saved_image += f", prompt: {text}"
|
||||||
|
|
||||||
|
shared.state.job_no = embedding.step
|
||||||
|
|
||||||
|
shared.state.textinfo = f"""
|
||||||
|
<p>
|
||||||
|
Loss: {losses.mean():.7f}<br/>
|
||||||
|
Step: {embedding.step}<br/>
|
||||||
|
Last prompt: {html.escape(text)}<br/>
|
||||||
|
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||||
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
|
</p>
|
||||||
|
"""
|
||||||
|
|
||||||
|
embedding.cached_checksum = None
|
||||||
|
embedding.save(filename)
|
||||||
|
|
||||||
|
return embedding, filename
|
||||||
|
|
||||||
@ -0,0 +1,32 @@
|
|||||||
|
import html
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
import modules.textual_inversion.textual_inversion as ti
|
||||||
|
from modules import sd_hijack, shared
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding(name, nvpt):
|
||||||
|
filename = ti.create_embedding(name, nvpt)
|
||||||
|
|
||||||
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
|
||||||
|
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
|
||||||
|
|
||||||
|
|
||||||
|
def train_embedding(*args):
|
||||||
|
|
||||||
|
try:
|
||||||
|
sd_hijack.undo_optimizations()
|
||||||
|
|
||||||
|
embedding, filename = ti.train_embedding(*args)
|
||||||
|
|
||||||
|
res = f"""
|
||||||
|
Training {'interrupted' if shared.state.interrupted else 'finished'} after {embedding.step} steps.
|
||||||
|
Embedding saved to {html.escape(filename)}
|
||||||
|
"""
|
||||||
|
return res, ""
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
sd_hijack.apply_optimizations()
|
||||||
@ -0,0 +1,121 @@
|
|||||||
|
import os
|
||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
|
import PIL
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import modules.shared
|
||||||
|
from modules import modelloader, shared
|
||||||
|
|
||||||
|
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||||
|
from modules.paths import models_path
|
||||||
|
|
||||||
|
|
||||||
|
class Upscaler:
|
||||||
|
name = None
|
||||||
|
model_path = None
|
||||||
|
model_name = None
|
||||||
|
model_url = None
|
||||||
|
enable = True
|
||||||
|
filter = None
|
||||||
|
model = None
|
||||||
|
user_path = None
|
||||||
|
scalers: []
|
||||||
|
tile = True
|
||||||
|
|
||||||
|
def __init__(self, create_dirs=False):
|
||||||
|
self.mod_pad_h = None
|
||||||
|
self.tile_size = modules.shared.opts.ESRGAN_tile
|
||||||
|
self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
|
||||||
|
self.device = modules.shared.device
|
||||||
|
self.img = None
|
||||||
|
self.output = None
|
||||||
|
self.scale = 1
|
||||||
|
self.half = not modules.shared.cmd_opts.no_half
|
||||||
|
self.pre_pad = 0
|
||||||
|
self.mod_scale = None
|
||||||
|
if self.name is not None and create_dirs:
|
||||||
|
self.model_path = os.path.join(models_path, self.name)
|
||||||
|
if not os.path.exists(self.model_path):
|
||||||
|
os.makedirs(self.model_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
self.can_tile = True
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def do_upscale(self, img: PIL.Image, selected_model: str):
|
||||||
|
return img
|
||||||
|
|
||||||
|
def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
|
||||||
|
self.scale = scale
|
||||||
|
dest_w = img.width * scale
|
||||||
|
dest_h = img.height * scale
|
||||||
|
for i in range(3):
|
||||||
|
if img.width >= dest_w and img.height >= dest_h:
|
||||||
|
break
|
||||||
|
img = self.do_upscale(img, selected_model)
|
||||||
|
if img.width != dest_w or img.height != dest_h:
|
||||||
|
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_model(self, path: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def find_models(self, ext_filter=None) -> list:
|
||||||
|
return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
|
||||||
|
|
||||||
|
def update_status(self, prompt):
|
||||||
|
print(f"\nextras: {prompt}", file=shared.progress_print_out)
|
||||||
|
|
||||||
|
|
||||||
|
class UpscalerData:
|
||||||
|
name = None
|
||||||
|
data_path = None
|
||||||
|
scale: int = 4
|
||||||
|
scaler: Upscaler = None
|
||||||
|
model: None
|
||||||
|
|
||||||
|
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
|
||||||
|
self.name = name
|
||||||
|
self.data_path = path
|
||||||
|
self.scaler = upscaler
|
||||||
|
self.scale = scale
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
|
||||||
|
class UpscalerNone(Upscaler):
|
||||||
|
name = "None"
|
||||||
|
scalers = []
|
||||||
|
|
||||||
|
def load_model(self, path):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def do_upscale(self, img, selected_model=None):
|
||||||
|
return img
|
||||||
|
|
||||||
|
def __init__(self, dirname=None):
|
||||||
|
super().__init__(False)
|
||||||
|
self.scalers = [UpscalerData("None", None, self)]
|
||||||
|
|
||||||
|
|
||||||
|
class UpscalerLanczos(Upscaler):
|
||||||
|
scalers = []
|
||||||
|
|
||||||
|
def do_upscale(self, img, selected_model=None):
|
||||||
|
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
|
||||||
|
|
||||||
|
def load_model(self, _):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __init__(self, dirname=None):
|
||||||
|
super().__init__(False)
|
||||||
|
self.name = "Lanczos"
|
||||||
|
self.scalers = [UpscalerData("Lanczos", None, self)]
|
||||||
|
|
||||||
@ -0,0 +1,19 @@
|
|||||||
|
a painting, art by [name]
|
||||||
|
a rendering, art by [name]
|
||||||
|
a cropped painting, art by [name]
|
||||||
|
the painting, art by [name]
|
||||||
|
a clean painting, art by [name]
|
||||||
|
a dirty painting, art by [name]
|
||||||
|
a dark painting, art by [name]
|
||||||
|
a picture, art by [name]
|
||||||
|
a cool painting, art by [name]
|
||||||
|
a close-up painting, art by [name]
|
||||||
|
a bright painting, art by [name]
|
||||||
|
a cropped painting, art by [name]
|
||||||
|
a good painting, art by [name]
|
||||||
|
a close-up painting, art by [name]
|
||||||
|
a rendition, art by [name]
|
||||||
|
a nice painting, art by [name]
|
||||||
|
a small painting, art by [name]
|
||||||
|
a weird painting, art by [name]
|
||||||
|
a large painting, art by [name]
|
||||||
@ -0,0 +1,19 @@
|
|||||||
|
a painting of [filewords], art by [name]
|
||||||
|
a rendering of [filewords], art by [name]
|
||||||
|
a cropped painting of [filewords], art by [name]
|
||||||
|
the painting of [filewords], art by [name]
|
||||||
|
a clean painting of [filewords], art by [name]
|
||||||
|
a dirty painting of [filewords], art by [name]
|
||||||
|
a dark painting of [filewords], art by [name]
|
||||||
|
a picture of [filewords], art by [name]
|
||||||
|
a cool painting of [filewords], art by [name]
|
||||||
|
a close-up painting of [filewords], art by [name]
|
||||||
|
a bright painting of [filewords], art by [name]
|
||||||
|
a cropped painting of [filewords], art by [name]
|
||||||
|
a good painting of [filewords], art by [name]
|
||||||
|
a close-up painting of [filewords], art by [name]
|
||||||
|
a rendition of [filewords], art by [name]
|
||||||
|
a nice painting of [filewords], art by [name]
|
||||||
|
a small painting of [filewords], art by [name]
|
||||||
|
a weird painting of [filewords], art by [name]
|
||||||
|
a large painting of [filewords], art by [name]
|
||||||
@ -0,0 +1,27 @@
|
|||||||
|
a photo of a [name]
|
||||||
|
a rendering of a [name]
|
||||||
|
a cropped photo of the [name]
|
||||||
|
the photo of a [name]
|
||||||
|
a photo of a clean [name]
|
||||||
|
a photo of a dirty [name]
|
||||||
|
a dark photo of the [name]
|
||||||
|
a photo of my [name]
|
||||||
|
a photo of the cool [name]
|
||||||
|
a close-up photo of a [name]
|
||||||
|
a bright photo of the [name]
|
||||||
|
a cropped photo of a [name]
|
||||||
|
a photo of the [name]
|
||||||
|
a good photo of the [name]
|
||||||
|
a photo of one [name]
|
||||||
|
a close-up photo of the [name]
|
||||||
|
a rendition of the [name]
|
||||||
|
a photo of the clean [name]
|
||||||
|
a rendition of a [name]
|
||||||
|
a photo of a nice [name]
|
||||||
|
a good photo of a [name]
|
||||||
|
a photo of the nice [name]
|
||||||
|
a photo of the small [name]
|
||||||
|
a photo of the weird [name]
|
||||||
|
a photo of the large [name]
|
||||||
|
a photo of a cool [name]
|
||||||
|
a photo of a small [name]
|
||||||
@ -0,0 +1,27 @@
|
|||||||
|
a photo of a [name], [filewords]
|
||||||
|
a rendering of a [name], [filewords]
|
||||||
|
a cropped photo of the [name], [filewords]
|
||||||
|
the photo of a [name], [filewords]
|
||||||
|
a photo of a clean [name], [filewords]
|
||||||
|
a photo of a dirty [name], [filewords]
|
||||||
|
a dark photo of the [name], [filewords]
|
||||||
|
a photo of my [name], [filewords]
|
||||||
|
a photo of the cool [name], [filewords]
|
||||||
|
a close-up photo of a [name], [filewords]
|
||||||
|
a bright photo of the [name], [filewords]
|
||||||
|
a cropped photo of a [name], [filewords]
|
||||||
|
a photo of the [name], [filewords]
|
||||||
|
a good photo of the [name], [filewords]
|
||||||
|
a photo of one [name], [filewords]
|
||||||
|
a close-up photo of the [name], [filewords]
|
||||||
|
a rendition of the [name], [filewords]
|
||||||
|
a photo of the clean [name], [filewords]
|
||||||
|
a rendition of a [name], [filewords]
|
||||||
|
a photo of a nice [name], [filewords]
|
||||||
|
a good photo of a [name], [filewords]
|
||||||
|
a photo of the nice [name], [filewords]
|
||||||
|
a photo of the small [name], [filewords]
|
||||||
|
a photo of the weird [name], [filewords]
|
||||||
|
a photo of the large [name], [filewords]
|
||||||
|
a photo of a cool [name], [filewords]
|
||||||
|
a photo of a small [name], [filewords]
|
||||||
Loading…
Reference in New Issue