Holy $hit.
Yep. Fix gfpgan_model_arch requirement(s). Add Upscaler base class, move from images. Add a lot of methods to Upscaler. Re-work all the child upscalers to be proper classes. Add BSRGAN scaler. Add ldsr_model_arch class, removing the dependency for another repo that just uses regular latent-diffusion stuff. Add one universal method that will always find and load new upscaler models without having to add new "setup_model" calls. Still need to add command line params, but that could probably be automated. Add a "self.scale" property to all Upscalers so the scalers themselves can do "things" in response to the requested upscaling size. Ensure LDSR doesn't get stuck in a longer loop of "upscale/downscale/upscale" as we try to reach the target upscale size. Add typehints for IDE sanity. PEP-8 improvements. Moar.master
parent
31ad536c33
commit
0dce0df1ee
@ -0,0 +1,79 @@
|
||||
import os.path
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import PIL.Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.upscaler
|
||||
from modules import shared, modelloader
|
||||
from modules.bsrgan_model_arch import RRDBNet
|
||||
from modules.paths import models_path
|
||||
|
||||
|
||||
class UpscalerBSRGAN(modules.upscaler.Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "BSRGAN"
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
self.model_name = "BSRGAN 4x"
|
||||
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
|
||||
self.user_path = dirname
|
||||
super().__init__()
|
||||
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
||||
scalers = []
|
||||
if len(model_paths) == 0:
|
||||
scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
for file in model_paths:
|
||||
if "http" in file:
|
||||
name = self.model_name
|
||||
else:
|
||||
name = modelloader.friendly_name(file)
|
||||
try:
|
||||
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
except Exception:
|
||||
print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
self.scalers = scalers
|
||||
|
||||
def do_upscale(self, img: PIL.Image, selected_file):
|
||||
torch.cuda.empty_cache()
|
||||
model = self.load_model(selected_file)
|
||||
if model is None:
|
||||
return img
|
||||
model.to(shared.device)
|
||||
torch.cuda.empty_cache()
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(shared.device)
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output = 255. * np.moveaxis(output, 0, 2)
|
||||
output = output.astype(np.uint8)
|
||||
output = output[:, :, ::-1]
|
||||
torch.cuda.empty_cache()
|
||||
return PIL.Image.fromarray(output, 'RGB')
|
||||
|
||||
def load_model(self, path: str):
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
||||
progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print("Unable to load %s from %s" % (self.model_dir, filename))
|
||||
return None
|
||||
print("Loading %s from %s" % (self.model_dir, filename))
|
||||
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2) # define network
|
||||
model.load_state_dict(torch.load(filename), strict=True)
|
||||
model.eval()
|
||||
for k, v in model.named_parameters():
|
||||
v.requires_grad = False
|
||||
return model
|
||||
|
||||
@ -0,0 +1,103 @@
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
|
||||
|
||||
def initialize_weights(net_l, scale=1):
|
||||
if not isinstance(net_l, list):
|
||||
net_l = [net_l]
|
||||
for net in net_l:
|
||||
for m in net.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale # for residual block
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias.data, 0.0)
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
# initialization
|
||||
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
'''Residual in Residual Dense Block'''
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
self.sf = sf
|
||||
print([in_nc, out_nc, nf, nb, gc, sf])
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
if self.sf==4:
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
|
||||
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
if self.sf==4:
|
||||
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
return out
|
||||
@ -1,150 +0,0 @@
|
||||
# GFPGAN likes to download stuff "wherever", and we're trying to fix that, so this is a copy of the original...
|
||||
|
||||
import cv2
|
||||
import os
|
||||
import torch
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from torchvision.transforms.functional import normalize
|
||||
|
||||
from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
|
||||
from gfpgan.archs.gfpganv1_arch import GFPGANv1
|
||||
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
|
||||
|
||||
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
class GFPGANerr():
|
||||
"""Helper for restoration with GFPGAN.
|
||||
|
||||
It will detect and crop faces, and then resize the faces to 512x512.
|
||||
GFPGAN is used to restored the resized faces.
|
||||
The background is upsampled with the bg_upsampler.
|
||||
Finally, the faces will be pasted back to the upsample background image.
|
||||
|
||||
Args:
|
||||
model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
|
||||
upscale (float): The upscale of the final output. Default: 2.
|
||||
arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
bg_upsampler (nn.Module): The upsampler for the background. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, model_path, model_dir, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
|
||||
self.upscale = upscale
|
||||
self.bg_upsampler = bg_upsampler
|
||||
|
||||
# initialize model
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
||||
# initialize the GFP-GAN
|
||||
if arch == 'clean':
|
||||
self.gfpgan = GFPGANv1Clean(
|
||||
out_size=512,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=channel_multiplier,
|
||||
decoder_load_path=None,
|
||||
fix_decoder=False,
|
||||
num_mlp=8,
|
||||
input_is_latent=True,
|
||||
different_w=True,
|
||||
narrow=1,
|
||||
sft_half=True)
|
||||
elif arch == 'bilinear':
|
||||
self.gfpgan = GFPGANBilinear(
|
||||
out_size=512,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=channel_multiplier,
|
||||
decoder_load_path=None,
|
||||
fix_decoder=False,
|
||||
num_mlp=8,
|
||||
input_is_latent=True,
|
||||
different_w=True,
|
||||
narrow=1,
|
||||
sft_half=True)
|
||||
elif arch == 'original':
|
||||
self.gfpgan = GFPGANv1(
|
||||
out_size=512,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=channel_multiplier,
|
||||
decoder_load_path=None,
|
||||
fix_decoder=True,
|
||||
num_mlp=8,
|
||||
input_is_latent=True,
|
||||
different_w=True,
|
||||
narrow=1,
|
||||
sft_half=True)
|
||||
elif arch == 'RestoreFormer':
|
||||
from gfpgan.archs.restoreformer_arch import RestoreFormer
|
||||
self.gfpgan = RestoreFormer()
|
||||
# initialize face helper
|
||||
self.face_helper = FaceRestoreHelper(
|
||||
upscale,
|
||||
face_size=512,
|
||||
crop_ratio=(1, 1),
|
||||
det_model='retinaface_resnet50',
|
||||
save_ext='png',
|
||||
use_parse=True,
|
||||
device=self.device,
|
||||
model_rootpath=model_dir)
|
||||
|
||||
if model_path.startswith('https://'):
|
||||
model_path = load_file_from_url(
|
||||
url=model_path, model_dir=model_dir, progress=True, file_name=None)
|
||||
loadnet = torch.load(model_path)
|
||||
if 'params_ema' in loadnet:
|
||||
keyname = 'params_ema'
|
||||
else:
|
||||
keyname = 'params'
|
||||
self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
|
||||
self.gfpgan.eval()
|
||||
self.gfpgan = self.gfpgan.to(self.device)
|
||||
|
||||
@torch.no_grad()
|
||||
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5):
|
||||
self.face_helper.clean_all()
|
||||
|
||||
if has_aligned: # the inputs are already aligned
|
||||
img = cv2.resize(img, (512, 512))
|
||||
self.face_helper.cropped_faces = [img]
|
||||
else:
|
||||
self.face_helper.read_image(img)
|
||||
# get face landmarks for each face
|
||||
self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
|
||||
# eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
|
||||
# TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
|
||||
# align and warp each face
|
||||
self.face_helper.align_warp_face()
|
||||
|
||||
# face restoration
|
||||
for cropped_face in self.face_helper.cropped_faces:
|
||||
# prepare data
|
||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
|
||||
|
||||
try:
|
||||
output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
|
||||
# convert to image
|
||||
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
|
||||
except RuntimeError as error:
|
||||
print(f'\tFailed inference for GFPGAN: {error}.')
|
||||
restored_face = cropped_face
|
||||
|
||||
restored_face = restored_face.astype('uint8')
|
||||
self.face_helper.add_restored_face(restored_face)
|
||||
|
||||
if not has_aligned and paste_back:
|
||||
# upsample the background
|
||||
if self.bg_upsampler is not None:
|
||||
# Now only support RealESRGAN for upsampling background
|
||||
bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
|
||||
else:
|
||||
bg_img = None
|
||||
|
||||
self.face_helper.get_inverse_affine(None)
|
||||
# paste each restored face to the input image
|
||||
restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
|
||||
return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
|
||||
else:
|
||||
return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
|
||||
@ -1,74 +1,45 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from collections import namedtuple
|
||||
|
||||
from modules import shared, images, modelloader, paths
|
||||
from modules.paths import models_path
|
||||
|
||||
model_dir = "LDSR"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
cmd_path = None
|
||||
model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
||||
yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"])
|
||||
|
||||
ldsr_models = []
|
||||
have_ldsr = False
|
||||
LDSR_obj = None
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.ldsr_model_arch import LDSR
|
||||
from modules import shared
|
||||
from modules.paths import models_path
|
||||
|
||||
|
||||
class UpscalerLDSR(images.Upscaler):
|
||||
def __init__(self, steps):
|
||||
self.steps = steps
|
||||
class UpscalerLDSR(Upscaler):
|
||||
def __init__(self, user_path):
|
||||
self.name = "LDSR"
|
||||
|
||||
def do_upscale(self, img):
|
||||
return upscale_with_ldsr(img)
|
||||
|
||||
|
||||
def setup_model(dirname):
|
||||
global cmd_path
|
||||
global model_path
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
cmd_path = dirname
|
||||
shared.sd_upscalers.append(UpscalerLDSR(100))
|
||||
|
||||
|
||||
def prepare_ldsr():
|
||||
path = paths.paths.get("LDSR", None)
|
||||
if path is None:
|
||||
return
|
||||
global have_ldsr
|
||||
global LDSR_obj
|
||||
try:
|
||||
from LDSR import LDSR
|
||||
model_files = modelloader.load_models(model_path, model_url, cmd_path, dl_name="model.ckpt", ext_filter=[".ckpt"])
|
||||
yaml_files = modelloader.load_models(model_path, yaml_url, cmd_path, dl_name="project.yaml", ext_filter=[".yaml"])
|
||||
if len(model_files) != 0 and len(yaml_files) != 0:
|
||||
model_file = model_files[0]
|
||||
yaml_file = yaml_files[0]
|
||||
have_ldsr = True
|
||||
LDSR_obj = LDSR(model_file, yaml_file)
|
||||
else:
|
||||
return
|
||||
|
||||
except Exception:
|
||||
print("Error importing LDSR:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
have_ldsr = False
|
||||
|
||||
|
||||
def upscale_with_ldsr(image):
|
||||
prepare_ldsr()
|
||||
if not have_ldsr or LDSR_obj is None:
|
||||
return image
|
||||
|
||||
ddim_steps = shared.opts.ldsr_steps
|
||||
pre_scale = shared.opts.ldsr_pre_down
|
||||
post_scale = shared.opts.ldsr_post_down
|
||||
|
||||
image = LDSR_obj.super_resolution(image, ddim_steps, pre_scale, post_scale)
|
||||
return image
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
self.user_path = user_path
|
||||
self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
||||
self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
||||
super().__init__()
|
||||
scaler_data = UpscalerData("LDSR", None, self)
|
||||
self.scalers = [scaler_data]
|
||||
|
||||
def load_model(self, path: str):
|
||||
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
||||
file_name="model.pth", progress=True)
|
||||
yaml = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
||||
file_name="project.yaml", progress=True)
|
||||
|
||||
try:
|
||||
return LDSR(model, yaml)
|
||||
|
||||
except Exception:
|
||||
print("Error importing LDSR:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return None
|
||||
|
||||
def do_upscale(self, img, path):
|
||||
ldsr = self.load_model(path)
|
||||
if ldsr is None:
|
||||
print("NO LDSR!")
|
||||
return img
|
||||
ddim_steps = shared.opts.ldsr_steps
|
||||
pre_scale = shared.opts.ldsr_pre_down
|
||||
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
||||
|
||||
@ -0,0 +1,223 @@
|
||||
import gc
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from einops import rearrange, repeat
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.util import instantiate_from_config, ismap
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
|
||||
# Create LDSR Class
|
||||
class LDSR:
|
||||
def load_model_from_config(self, half_attention):
|
||||
print(f"Loading model from {self.modelPath}")
|
||||
pl_sd = torch.load(self.modelPath, map_location="cpu")
|
||||
sd = pl_sd["state_dict"]
|
||||
config = OmegaConf.load(self.yamlPath)
|
||||
model = instantiate_from_config(config.model)
|
||||
model.load_state_dict(sd, strict=False)
|
||||
model.cuda()
|
||||
if half_attention:
|
||||
model = model.half()
|
||||
|
||||
model.eval()
|
||||
return {"model": model}
|
||||
|
||||
def __init__(self, model_path, yaml_path):
|
||||
self.modelPath = model_path
|
||||
self.yamlPath = yaml_path
|
||||
|
||||
@staticmethod
|
||||
def run(model, selected_path, custom_steps, eta):
|
||||
example = get_cond(selected_path)
|
||||
|
||||
n_runs = 1
|
||||
guider = None
|
||||
ckwargs = None
|
||||
ddim_use_x0_pred = False
|
||||
temperature = 1.
|
||||
eta = eta
|
||||
custom_shape = None
|
||||
|
||||
height, width = example["image"].shape[1:3]
|
||||
split_input = height >= 128 and width >= 128
|
||||
|
||||
if split_input:
|
||||
ks = 128
|
||||
stride = 64
|
||||
vqf = 4 #
|
||||
model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
|
||||
"vqf": vqf,
|
||||
"patch_distributed_vq": True,
|
||||
"tie_braker": False,
|
||||
"clip_max_weight": 0.5,
|
||||
"clip_min_weight": 0.01,
|
||||
"clip_max_tie_weight": 0.5,
|
||||
"clip_min_tie_weight": 0.01}
|
||||
else:
|
||||
if hasattr(model, "split_input_params"):
|
||||
delattr(model, "split_input_params")
|
||||
|
||||
x_t = None
|
||||
logs = None
|
||||
for n in range(n_runs):
|
||||
if custom_shape is not None:
|
||||
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
|
||||
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
|
||||
|
||||
logs = make_convolutional_sample(example, model,
|
||||
custom_steps=custom_steps,
|
||||
eta=eta, quantize_x0=False,
|
||||
custom_shape=custom_shape,
|
||||
temperature=temperature, noise_dropout=0.,
|
||||
corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
|
||||
ddim_use_x0_pred=ddim_use_x0_pred
|
||||
)
|
||||
return logs
|
||||
|
||||
def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
|
||||
model = self.load_model_from_config(half_attention)
|
||||
|
||||
# Run settings
|
||||
diffusion_steps = int(steps)
|
||||
eta = 1.0
|
||||
|
||||
down_sample_method = 'Lanczos'
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
im_og = image
|
||||
width_og, height_og = im_og.size
|
||||
# If we can adjust the max upscale size, then the 4 below should be our variable
|
||||
print("Foo")
|
||||
down_sample_rate = target_scale / 4
|
||||
print(f"Downsample rate is {down_sample_rate}")
|
||||
width_downsampled_pre = width_og * down_sample_rate
|
||||
height_downsampled_pre = height_og * down_sample_method
|
||||
|
||||
if down_sample_rate != 1:
|
||||
print(
|
||||
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
|
||||
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
||||
else:
|
||||
print(f"Down sample rate is 1 from {target_scale} / 4")
|
||||
logs = self.run(model["model"], im_og, diffusion_steps, eta)
|
||||
|
||||
sample = logs["sample"]
|
||||
sample = sample.detach().cpu()
|
||||
sample = torch.clamp(sample, -1., 1.)
|
||||
sample = (sample + 1.) / 2. * 255
|
||||
sample = sample.numpy().astype(np.uint8)
|
||||
sample = np.transpose(sample, (0, 2, 3, 1))
|
||||
a = Image.fromarray(sample[0])
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
print(f'Processing finished!')
|
||||
return a
|
||||
|
||||
|
||||
def get_cond(selected_path):
|
||||
example = dict()
|
||||
up_f = 4
|
||||
c = selected_path.convert('RGB')
|
||||
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
|
||||
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
|
||||
antialias=True)
|
||||
c_up = rearrange(c_up, '1 c h w -> 1 h w c')
|
||||
c = rearrange(c, '1 c h w -> 1 h w c')
|
||||
c = 2. * c - 1.
|
||||
|
||||
c = c.to(torch.device("cuda"))
|
||||
example["LR_image"] = c
|
||||
example["image"] = c_up
|
||||
|
||||
return example
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
|
||||
mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
|
||||
corrector_kwargs=None, x_t=None
|
||||
):
|
||||
ddim = DDIMSampler(model)
|
||||
bs = shape[0]
|
||||
shape = shape[1:]
|
||||
print(f"Sampling with eta = {eta}; steps: {steps}")
|
||||
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
|
||||
normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
|
||||
mask=mask, x0=x0, temperature=temperature, verbose=False,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs, x_t=x_t)
|
||||
|
||||
return samples, intermediates
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
|
||||
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
|
||||
log = dict()
|
||||
|
||||
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
|
||||
return_first_stage_outputs=True,
|
||||
force_c_encode=not (hasattr(model, 'split_input_params')
|
||||
and model.cond_stage_key == 'coordinates_bbox'),
|
||||
return_original_cond=True)
|
||||
|
||||
if custom_shape is not None:
|
||||
z = torch.randn(custom_shape)
|
||||
print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
|
||||
|
||||
z0 = None
|
||||
|
||||
log["input"] = x
|
||||
log["reconstruction"] = xrec
|
||||
|
||||
if ismap(xc):
|
||||
log["original_conditioning"] = model.to_rgb(xc)
|
||||
if hasattr(model, 'cond_stage_key'):
|
||||
log[model.cond_stage_key] = model.to_rgb(xc)
|
||||
|
||||
else:
|
||||
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
|
||||
if model.cond_stage_model:
|
||||
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
|
||||
if model.cond_stage_key == 'class_label':
|
||||
log[model.cond_stage_key] = xc[model.cond_stage_key]
|
||||
|
||||
with model.ema_scope("Plotting"):
|
||||
t0 = time.time()
|
||||
|
||||
sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
|
||||
eta=eta,
|
||||
quantize_x0=quantize_x0, mask=None, x0=z0,
|
||||
temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
|
||||
x_t=x_T)
|
||||
t1 = time.time()
|
||||
|
||||
if ddim_use_x0_pred:
|
||||
sample = intermediates['pred_x0'][-1]
|
||||
|
||||
x_sample = model.decode_first_stage(sample)
|
||||
|
||||
try:
|
||||
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
||||
log["sample_noquant"] = x_sample_noquant
|
||||
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
||||
except:
|
||||
pass
|
||||
|
||||
log["sample"] = x_sample
|
||||
log["time"] = t1 - t0
|
||||
|
||||
return log
|
||||
@ -0,0 +1,121 @@
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import modules.shared
|
||||
from modules import modelloader, shared
|
||||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
from modules.paths import models_path
|
||||
|
||||
|
||||
class Upscaler:
|
||||
name = None
|
||||
model_path = None
|
||||
model_name = None
|
||||
model_url = None
|
||||
enable = True
|
||||
filter = None
|
||||
model = None
|
||||
user_path = None
|
||||
scalers: []
|
||||
tile = True
|
||||
|
||||
def __init__(self, create_dirs=False):
|
||||
self.mod_pad_h = None
|
||||
self.tile_size = modules.shared.opts.ESRGAN_tile
|
||||
self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
|
||||
self.device = modules.shared.device
|
||||
self.img = None
|
||||
self.output = None
|
||||
self.scale = 1
|
||||
self.half = not modules.shared.cmd_opts.no_half
|
||||
self.pre_pad = 0
|
||||
self.mod_scale = None
|
||||
if self.name is not None and create_dirs:
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
if not os.path.exists(self.model_path):
|
||||
os.makedirs(self.model_path)
|
||||
|
||||
try:
|
||||
import cv2
|
||||
self.can_tile = True
|
||||
except:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def do_upscale(self, img: PIL.Image, selected_model: str):
|
||||
return img
|
||||
|
||||
def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
|
||||
self.scale = scale
|
||||
dest_w = img.width * scale
|
||||
dest_h = img.height * scale
|
||||
for i in range(3):
|
||||
if img.width >= dest_w and img.height >= dest_h:
|
||||
break
|
||||
img = self.do_upscale(img, selected_model)
|
||||
if img.width != dest_w or img.height != dest_h:
|
||||
img = img.resize(dest_w, dest_h, resample=LANCZOS)
|
||||
|
||||
return img
|
||||
|
||||
@abstractmethod
|
||||
def load_model(self, path: str):
|
||||
pass
|
||||
|
||||
def find_models(self, ext_filter=None) -> list:
|
||||
return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
|
||||
|
||||
def update_status(self, prompt):
|
||||
print(f"\nextras: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
|
||||
class UpscalerData:
|
||||
name = None
|
||||
data_path = None
|
||||
scale: int = 4
|
||||
scaler: Upscaler = None
|
||||
model: None
|
||||
|
||||
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
|
||||
self.name = name
|
||||
self.data_path = path
|
||||
self.scaler = upscaler
|
||||
self.scale = scale
|
||||
self.model = model
|
||||
|
||||
|
||||
class UpscalerNone(Upscaler):
|
||||
name = "None"
|
||||
scalers = []
|
||||
|
||||
def load_model(self, path):
|
||||
pass
|
||||
|
||||
def do_upscale(self, img, selected_model=None):
|
||||
return img
|
||||
|
||||
def __init__(self, dirname=None):
|
||||
super().__init__(False)
|
||||
self.scalers = [UpscalerData("None", None, self)]
|
||||
|
||||
|
||||
class UpscalerLanczos(Upscaler):
|
||||
scalers = []
|
||||
|
||||
def do_upscale(self, img, selected_model=None):
|
||||
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
|
||||
|
||||
def load_model(self, _):
|
||||
pass
|
||||
|
||||
def __init__(self, dirname=None):
|
||||
super().__init__(False)
|
||||
self.name = "Lanczos"
|
||||
self.scalers = [UpscalerData("Lanczos", None, self)]
|
||||
|
||||
Loading…
Reference in New Issue