Merge branch 'AUTOMATIC1111:master' into master
commit 25de9df364
@@ -1 +1,13 @@
* @AUTOMATIC1111
/localizations/ar_AR.json @xmodar @blackneoo
/localizations/de_DE.json @LunixWasTaken
/localizations/es_ES.json @innovaciones
/localizations/fr_FR.json @tumbly
/localizations/it_IT.json @EugenioBuffo
/localizations/ja_JP.json @yuuki76
/localizations/ko_KR.json @36DB
/localizations/pt_BR.json @M-art-ucci
/localizations/ru_RU.json @kabachuha
/localizations/tr_TR.json @camenduru
/localizations/zh_CN.json @dtlnor @bgluminous
/localizations/zh_TW.json @benlisquare

@@ -0,0 +1,24 @@
function extensions_apply(_, _){
    disable = []
    update = []
    gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
        if(x.name.startsWith("enable_") && ! x.checked)
            disable.push(x.name.substr(7))

        if(x.name.startsWith("update_") && x.checked)
            update.push(x.name.substr(7))
    })

    restart_reload()

    return [JSON.stringify(disable), JSON.stringify(update)]
}

function extensions_check(){
    gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
        x.innerHTML = "Loading..."
    })

    return []
}

@@ -1,206 +0,0 @@
var images_history_click_image = function(){
    if (!this.classList.contains("transform")){
        var gallery = images_history_get_parent_by_class(this, "images_history_cantainor");
        var buttons = gallery.querySelectorAll(".gallery-item");
        var i = 0;
        var hidden_list = [];
        buttons.forEach(function(e){
            if (e.style.display == "none"){
                hidden_list.push(i);
            }
            i += 1;
        })
        if (hidden_list.length > 0){
            setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
        }
    }
    images_history_set_image_info(this);
}

var images_history_click_tab = function(){
    var tabs_box = gradioApp().getElementById("images_history_tab");
    if (!tabs_box.classList.contains(this.getAttribute("tabname"))) {
        gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_renew_page").click();
        tabs_box.classList.add(this.getAttribute("tabname"))
    }
}

function images_history_disabled_del(){
    gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
        btn.setAttribute('disabled','disabled');
    });
}

function images_history_get_parent_by_class(item, class_name){
    var parent = item.parentElement;
    while(!parent.classList.contains(class_name)){
        parent = parent.parentElement;
    }
    return parent;
}

function images_history_get_parent_by_tagname(item, tagname){
    var parent = item.parentElement;
    tagname = tagname.toUpperCase()
    while(parent.tagName != tagname){
        console.log(parent.tagName, tagname)
        parent = parent.parentElement;
    }
    return parent;
}

function images_history_hide_buttons(hidden_list, gallery){
    var buttons = gallery.querySelectorAll(".gallery-item");
    var num = 0;
    buttons.forEach(function(e){
        if (e.style.display == "none"){
            num += 1;
        }
    });
    if (num == hidden_list.length){
        setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
    }
    for( i in hidden_list){
        buttons[hidden_list[i]].style.display = "none";
    }
}

function images_history_set_image_info(button){
    var buttons = images_history_get_parent_by_tagname(button, "DIV").querySelectorAll(".gallery-item");
    var index = -1;
    var i = 0;
    buttons.forEach(function(e){
        if(e == button){
            index = i;
        }
        if(e.style.display != "none"){
            i += 1;
        }
    });
    var gallery = images_history_get_parent_by_class(button, "images_history_cantainor");
    var set_btn = gallery.querySelector(".images_history_set_index");
    var curr_idx = set_btn.getAttribute("img_index", index);
    if (curr_idx != index) {
        set_btn.setAttribute("img_index", index);
        images_history_disabled_del();
    }
    set_btn.click();
}

function images_history_get_current_img(tabname, image_path, files){
    return [
        gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"),
        image_path,
        files
    ];
}

function images_history_delete(del_num, tabname, img_path, img_file_name, page_index, filenames, image_index){
    image_index = parseInt(image_index);
    var tab = gradioApp().getElementById(tabname + '_images_history');
    var set_btn = tab.querySelector(".images_history_set_index");
    var buttons = [];
    tab.querySelectorAll(".gallery-item").forEach(function(e){
        if (e.style.display != 'none'){
            buttons.push(e);
        }
    });
    var img_num = buttons.length / 2;
    if (img_num <= del_num){
        setTimeout(function(tabname){
            gradioApp().getElementById(tabname + '_images_history_renew_page').click();
        }, 30, tabname);
    } else {
        var next_img
        for (var i = 0; i < del_num; i++){
            if (image_index + i < image_index + img_num){
                buttons[image_index + i].style.display = 'none';
                buttons[image_index + img_num + 1].style.display = 'none';
                next_img = image_index + i + 1
            }
        }
        var bnt;
        if (next_img >= img_num){
            btn = buttons[image_index - del_num];
        } else {
            btn = buttons[next_img];
        }
        setTimeout(function(btn){btn.click()}, 30, btn);
    }
    images_history_disabled_del();
    return [del_num, tabname, img_path, img_file_name, page_index, filenames, image_index];
}

function images_history_turnpage(img_path, page_index, image_index, tabname){
    var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item");
    buttons.forEach(function(elem) {
        elem.style.display = 'block';
    })
    return [img_path, page_index, image_index, tabname];
}

function images_history_enable_del_buttons(){
    gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
        btn.removeAttribute('disabled');
    })
}

function images_history_init(){
    var load_txt2img_button = gradioApp().getElementById('txt2img_images_history_renew_page')
    if (load_txt2img_button){
        for (var i in images_history_tab_list ){
            tab = images_history_tab_list[i];
            gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor");
            gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index");
            gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button");
            gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery");
        }
        var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div");
        tabs_box.setAttribute("id", "images_history_tab");
        var tab_btns = tabs_box.querySelectorAll("button");
        for (var i in images_history_tab_list){
            var tabname = images_history_tab_list[i]
            tab_btns[i].setAttribute("tabname", tabname);

            // this refreshes history upon tab switch
            // until the history is known to work well, which is not the case now, we do not do this at startup
            //tab_btns[i].addEventListener('click', images_history_click_tab);
        }
        tabs_box.classList.add(images_history_tab_list[0]);

        // same as above, at page load
        //load_txt2img_button.click();
    } else {
        setTimeout(images_history_init, 500);
    }
}

var images_history_tab_list = ["txt2img", "img2img", "extras"];
setTimeout(images_history_init, 500);
document.addEventListener("DOMContentLoaded", function() {
    var mutationObserver = new MutationObserver(function(m){
        for (var i in images_history_tab_list ){
            let tabname = images_history_tab_list[i]
            var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item');
            buttons.forEach(function(bnt){
                bnt.addEventListener('click', images_history_click_image, true);
            });

            // same as load_txt2img_button.click() above
            /*
            var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
            if (cls_btn){
                cls_btn.addEventListener('click', function(){
                    gradioApp().getElementById(tabname + '_images_history_renew_page').click();
                }, false);
            }*/
        }
    });
    mutationObserver.observe( gradioApp(), { childList:true, subtree:true });
});

File diff suppressed because it is too large

@@ -1,241 +0,0 @@
import copy
import itertools
import os
from pathlib import Path
import html
import gc

import gradio as gr
import torch
from PIL import Image
from torch import optim

from modules import shared
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
from tqdm.auto import tqdm, trange
from modules.shared import opts, device


def get_all_images_in_folder(folder):
    return [os.path.join(folder, f) for f in os.listdir(folder) if
            os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]


def check_is_valid_image_file(filename):
    return filename.lower().endswith(('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp"))


def batched(dataset, total, n=1):
    for ndx in range(0, total, n):
        yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]


def iter_to_batched(iterable, n=1):
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk


def create_ui():
    import modules.ui

    with gr.Group():
        with gr.Accordion("Open for Clip Aesthetic!", open=False):
            with gr.Row():
                aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight",
                                             value=0.9)
                aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)

            with gr.Row():
                aesthetic_lr = gr.Textbox(label='Aesthetic learning rate',
                                          placeholder="Aesthetic learning rate", value="0.0001")
                aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
                aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()),
                                             label="Aesthetic imgs embedding",
                                             value="None")

                modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings")

            with gr.Row():
                aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
                                                 placeholder="This text is used to rotate the feature space of the imgs embs",
                                                 value="")
                aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01,
                                                  value=0.1)
                aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)

    return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative


aesthetic_clip_model = None


def aesthetic_clip():
    global aesthetic_clip_model

    if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path:
        aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path)
        aesthetic_clip_model.cpu()

    return aesthetic_clip_model


def generate_imgs_embd(name, folder, batch_size):
    model = aesthetic_clip().to(device)
    processor = CLIPProcessor.from_pretrained(model.name_or_path)

    with torch.no_grad():
        embs = []
        for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
                          desc=f"Generating embeddings for {name}"):
            if shared.state.interrupted:
                break
            inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
            outputs = model.get_image_features(**inputs).cpu()
            embs.append(torch.clone(outputs))
            inputs.to("cpu")
            del inputs, outputs

        embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)

        # The generated embedding will be located here
        path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
        torch.save(embs, path)

        model.cpu()
        del processor
        del embs
        gc.collect()
        torch.cuda.empty_cache()
        res = f"""
        Done generating embedding for {name}!
        Aesthetic embedding saved to {html.escape(path)}
        """
        shared.update_aesthetic_embeddings()
        return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
                                  value="None"), \
               gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()),
                                  label="Imgs embedding",
                                  value="None"), res, ""


def slerp(low, high, val):
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res


class AestheticCLIP:
    def __init__(self):
        self.skip = False
        self.aesthetic_steps = 0
        self.aesthetic_weight = 0
        self.aesthetic_lr = 0
        self.slerp = False
        self.aesthetic_text_negative = ""
        self.aesthetic_slerp_angle = 0
        self.aesthetic_imgs_text = ""

        self.image_embs_name = None
        self.image_embs = None
        self.load_image_embs(None)

    def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
                             aesthetic_slerp=True, aesthetic_imgs_text="",
                             aesthetic_slerp_angle=0.15,
                             aesthetic_text_negative=False):
        self.aesthetic_imgs_text = aesthetic_imgs_text
        self.aesthetic_slerp_angle = aesthetic_slerp_angle
        self.aesthetic_text_negative = aesthetic_text_negative
        self.slerp = aesthetic_slerp
        self.aesthetic_lr = aesthetic_lr
        self.aesthetic_weight = aesthetic_weight
        self.aesthetic_steps = aesthetic_steps
        self.load_image_embs(image_embs_name)

        if self.image_embs_name is not None:
            p.extra_generation_params.update({
                "Aesthetic LR": aesthetic_lr,
                "Aesthetic weight": aesthetic_weight,
                "Aesthetic steps": aesthetic_steps,
                "Aesthetic embedding": self.image_embs_name,
                "Aesthetic slerp": aesthetic_slerp,
                "Aesthetic text": aesthetic_imgs_text,
                "Aesthetic text negative": aesthetic_text_negative,
                "Aesthetic slerp angle": aesthetic_slerp_angle,
            })

    def set_skip(self, skip):
        self.skip = skip

    def load_image_embs(self, image_embs_name):
        if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
            image_embs_name = None
            self.image_embs_name = None
        if image_embs_name is not None and self.image_embs_name != image_embs_name:
            self.image_embs_name = image_embs_name
            self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
            self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
            self.image_embs.requires_grad_(False)

    def __call__(self, z, remade_batch_tokens):
        if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None:
            tokenizer = shared.sd_model.cond_stage_model.tokenizer
            if not opts.use_old_emphasis_implementation:
                remade_batch_tokens = [
                    [tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in
                    remade_batch_tokens]

            tokens = torch.asarray(remade_batch_tokens).to(device)

            model = copy.deepcopy(aesthetic_clip()).to(device)
            model.requires_grad_(True)
            if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
                text_embs_2 = model.get_text_features(
                    **tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
                if self.aesthetic_text_negative:
                    text_embs_2 = self.image_embs - text_embs_2
                    text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
                img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
            else:
                img_embs = self.image_embs

            with torch.enable_grad():

                # We optimize the model to maximize the similarity
                optimizer = optim.Adam(
                    model.text_model.parameters(), lr=self.aesthetic_lr
                )

                for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
                    text_embs = model.get_text_features(input_ids=tokens)
                    text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
                    sim = text_embs @ img_embs.T
                    loss = -sim
                    optimizer.zero_grad()
                    loss.mean().backward()
                    optimizer.step()

                zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
                if opts.CLIP_stop_at_last_layers > 1:
                    zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
                    zn = model.text_model.final_layer_norm(zn)
                else:
                    zn = zn.last_hidden_state
                model.cpu()
                del model
                gc.collect()
                torch.cuda.empty_cache()
            zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1)
            if self.slerp:
                z = slerp(z, zn, self.aesthetic_weight)
            else:
                z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight

        return z

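The slerp() helper in the deleted file above interpolates along the unit sphere rather than linearly. A minimal self-contained sketch (my illustration, not part of the commit) checks the key property: interpolating between two unit-norm embeddings stays on the unit sphere, which a plain weighted average would not.

import torch

def slerp(low, high, val):
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    return (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high

a = torch.nn.functional.normalize(torch.randn(1, 768), dim=1)
b = torch.nn.functional.normalize(torch.randn(1, 768), dim=1)
mid = slerp(a, b, 0.5)
print(mid.norm(dim=1))  # ~1.0: the interpolant keeps unit norm
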
@@ -0,0 +1,167 @@
import inspect
from click import prompt
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers

API_NOT_ALLOWED = [
    "self",
    "kwargs",
    "sd_model",
    "outpath_samples",
    "outpath_grids",
    "sampler_index",
    "do_not_save_samples",
    "do_not_save_grid",
    "extra_generation_params",
    "overlay_images",
    "do_not_reload_embeddings",
    "seed_enable_extras",
    "prompt_for_display",
    "sampler_noise_scheduler_override",
    "ddim_discretize"
]

class ModelDef(BaseModel):
    """Assistance Class for Pydantic Dynamic Model Generation"""

    field: str
    field_alias: str
    field_type: Any
    field_value: Any
    field_exclude: bool = False


class PydanticModelGenerator:
    """
    Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
    source_data is a snapshot of the default values produced by the class
    params are the names of the actual keys required by __init__
    """

    def __init__(
        self,
        model_name: str = None,
        class_instance = None,
        additional_fields = None,
    ):
        def field_type_generator(k, v):
            # field_type = str if not overrides.get(k) else overrides[k]["type"]
            # print(k, v.annotation, v.default)
            field_type = v.annotation

            return Optional[field_type]

        def merge_class_params(class_):
            all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
            parameters = {}
            for classes in all_classes:
                parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
            return parameters


        self._model_name = model_name
        self._class_data = merge_class_params(class_instance)
        self._model_def = [
            ModelDef(
                field=underscore(k),
                field_alias=k,
                field_type=field_type_generator(k, v),
                field_value=v.default
            )
            for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
        ]

        for fields in additional_fields:
            self._model_def.append(ModelDef(
                field=underscore(fields["key"]),
                field_alias=fields["key"],
                field_type=fields["type"],
                field_value=fields["default"],
                field_exclude=fields["exclude"] if "exclude" in fields else False))

    def generate_model(self):
        """
        Creates a pydantic BaseModel
        from the json and overrides provided at initialization
        """
        fields = {
            d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
        }
        DynamicModel = create_model(self._model_name, **fields)
        DynamicModel.__config__.allow_population_by_field_name = True
        DynamicModel.__config__.allow_mutation = True
        return DynamicModel

StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
    "StableDiffusionProcessingTxt2Img",
    StableDiffusionProcessingTxt2Img,
    [{"key": "sampler_index", "type": str, "default": "Euler"}]
).generate_model()

StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
    "StableDiffusionProcessingImg2Img",
    StableDiffusionProcessingImg2Img,
    [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
).generate_model()

class TextToImageResponse(BaseModel):
    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
    parameters: dict
    info: str

class ImageToImageResponse(BaseModel):
    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
    parameters: dict
    info: str

class ExtrasBaseRequest(BaseModel):
    resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
    show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
    gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
    codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
    codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
    upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
    upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
    upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
    upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
    upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
    upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
    extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")

class ExtraBaseResponse(BaseModel):
    html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")

class ExtrasSingleImageRequest(ExtrasBaseRequest):
    image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")

class ExtrasSingleImageResponse(ExtraBaseResponse):
    image: str = Field(default=None, title="Image", description="The generated image in base64 format.")

class FileData(BaseModel):
    data: str = Field(title="File data", description="Base64 representation of the file")
    name: str = Field(title="File name")

class ExtrasBatchImagesRequest(ExtrasBaseRequest):
    imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")

class ExtrasBatchImagesResponse(ExtraBaseResponse):
    images: list[str] = Field(title="Images", description="The generated images in base64 format.")

class PNGInfoRequest(BaseModel):
    image: str = Field(title="Image", description="The base64 encoded PNG image")

class PNGInfoResponse(BaseModel):
    info: str = Field(title="Image info", description="A string with all the info the image had")

class ProgressRequest(BaseModel):
    skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")

class ProgressResponse(BaseModel):
    progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
    eta_relative: float = Field(title="ETA in secs")
    state: dict = Field(title="State", description="The current state snapshot")
    current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")

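To see what PydanticModelGenerator is doing, here is a minimal self-contained sketch of the same idea, assuming pydantic v1 (the API the file above targets). The Demo class and its fields are hypothetical, not part of the codebase: defaults are snapshotted from __init__ the same way merge_class_params does, then fed to create_model.

import inspect
from typing import Optional
from pydantic import Field, create_model

class Demo:
    def __init__(self, prompt: str = "", steps: int = 20, cfg_scale: float = 7.0):
        self.prompt, self.steps, self.cfg_scale = prompt, steps, cfg_scale

# Snapshot the __init__ parameters: each becomes an optional field with its default.
params = inspect.signature(Demo.__init__).parameters
fields = {
    name: (Optional[p.annotation], Field(default=p.default))
    for name, p in params.items()
    if name != "self"
}
DemoAPI = create_model("DemoAPI", **fields)

req = DemoAPI(prompt="a photo of a cat", steps=28)
print(req.dict())  # {'prompt': 'a photo of a cat', 'steps': 28, 'cfg_scale': 7.0}
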
@@ -1,99 +0,0 @@
from inflection import underscore
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, create_model
from modules.processing import StableDiffusionProcessingTxt2Img
import inspect


API_NOT_ALLOWED = [
    "self",
    "kwargs",
    "sd_model",
    "outpath_samples",
    "outpath_grids",
    "sampler_index",
    "do_not_save_samples",
    "do_not_save_grid",
    "extra_generation_params",
    "overlay_images",
    "do_not_reload_embeddings",
    "seed_enable_extras",
    "prompt_for_display",
    "sampler_noise_scheduler_override",
    "ddim_discretize"
]

class ModelDef(BaseModel):
    """Assistance Class for Pydantic Dynamic Model Generation"""

    field: str
    field_alias: str
    field_type: Any
    field_value: Any


class PydanticModelGenerator:
    """
    Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
    source_data is a snapshot of the default values produced by the class
    params are the names of the actual keys required by __init__
    """

    def __init__(
        self,
        model_name: str = None,
        class_instance = None,
        additional_fields = None,
    ):
        def field_type_generator(k, v):
            # field_type = str if not overrides.get(k) else overrides[k]["type"]
            # print(k, v.annotation, v.default)
            field_type = v.annotation

            return Optional[field_type]

        def merge_class_params(class_):
            all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
            parameters = {}
            for classes in all_classes:
                parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
            return parameters


        self._model_name = model_name
        self._class_data = merge_class_params(class_instance)
        self._model_def = [
            ModelDef(
                field=underscore(k),
                field_alias=k,
                field_type=field_type_generator(k, v),
                field_value=v.default
            )
            for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
        ]

        for fields in additional_fields:
            self._model_def.append(ModelDef(
                field=underscore(fields["key"]),
                field_alias=fields["key"],
                field_type=fields["type"],
                field_value=fields["default"]))

    def generate_model(self):
        """
        Creates a pydantic BaseModel
        from the json and overrides provided at initialization
        """
        fields = {
            d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
        }
        DynamicModel = create_model(self._model_name, **fields)
        DynamicModel.__config__.allow_population_by_field_name = True
        DynamicModel.__config__.allow_mutation = True
        return DynamicModel

StableDiffusionProcessingAPI = PydanticModelGenerator(
    "StableDiffusionProcessingTxt2Img",
    StableDiffusionProcessingTxt2Img,
    [{"key": "sampler_index", "type": str, "default": "Euler"}]
).generate_model()

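The generated request models are what the API endpoints validate incoming JSON against. A hedged usage sketch, assuming the webui is running locally with --api and serves the txt2img route under /sdapi/v1/ as in modules/api/api.py; the prompt values are placeholders:

import json
import urllib.request

payload = {
    "prompt": "a photo of a cat",
    "steps": 20,
    "sampler_index": "Euler",  # the extra field injected via additional_fields
}
req = urllib.request.Request(
    "http://127.0.0.1:7860/sdapi/v1/txt2img",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    result = json.load(resp)
print(result["info"])  # fields match TextToImageResponse: images, parameters, info
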
@@ -1,76 +0,0 @@
import os.path
import sys
import traceback

import PIL.Image
import numpy as np
import torch
from basicsr.utils.download_util import load_file_from_url

import modules.upscaler
from modules import devices, modelloader
from modules.bsrgan_model_arch import RRDBNet


class UpscalerBSRGAN(modules.upscaler.Upscaler):
    def __init__(self, dirname):
        self.name = "BSRGAN"
        self.model_name = "BSRGAN 4x"
        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
        self.user_path = dirname
        super().__init__()
        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
        scalers = []
        if len(model_paths) == 0:
            scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
            scalers.append(scaler_data)
        for file in model_paths:
            if "http" in file:
                name = self.model_name
            else:
                name = modelloader.friendly_name(file)
            try:
                scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
                scalers.append(scaler_data)
            except Exception:
                print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
        self.scalers = scalers

    def do_upscale(self, img: PIL.Image, selected_file):
        torch.cuda.empty_cache()
        model = self.load_model(selected_file)
        if model is None:
            return img
        model.to(devices.device_bsrgan)
        torch.cuda.empty_cache()
        img = np.array(img)
        img = img[:, :, ::-1]
        img = np.moveaxis(img, 2, 0) / 255
        img = torch.from_numpy(img).float()
        img = img.unsqueeze(0).to(devices.device_bsrgan)
        with torch.no_grad():
            output = model(img)
        output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = 255. * np.moveaxis(output, 0, 2)
        output = output.astype(np.uint8)
        output = output[:, :, ::-1]
        torch.cuda.empty_cache()
        return PIL.Image.fromarray(output, 'RGB')

    def load_model(self, path: str):
        if "http" in path:
            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
                                          progress=True)
        else:
            filename = path
        if not os.path.exists(filename) or filename is None:
            print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
            return None
        model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4)  # define network
        model.load_state_dict(torch.load(filename), strict=True)
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        return model

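do_upscale above round-trips between PIL's HWC uint8 RGB layout and the BCHW float BGR tensor the network expects. A small sketch of just that conversion on a dummy array, with the model call stubbed out (illustration only, not part of the commit):

import numpy as np
import torch

rgb = np.random.randint(0, 256, (8, 8, 3), dtype=np.uint8)  # HWC, RGB
x = rgb[:, :, ::-1]                     # RGB -> BGR (the network is BGR-trained)
x = np.moveaxis(x, 2, 0) / 255          # HWC -> CHW, scaled to [0, 1]
t = torch.from_numpy(x).float().unsqueeze(0)  # add a batch dimension

out = t                                 # stand-in for model(t)
y = out.squeeze().float().clamp_(0, 1).numpy()
y = (255. * np.moveaxis(y, 0, 2)).astype(np.uint8)[:, :, ::-1]  # undo everything
assert np.abs(y.astype(int) - rgb.astype(int)).max() <= 1  # equal up to float rounding
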
@@ -1,102 +0,0 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


def initialize_weights(net_l, scale=1):
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)


def make_layer(block, n_layers):
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    return nn.Sequential(*layers)


class ResidualDenseBlock_5C(nn.Module):
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(nn.Module):
    '''Residual in Residual Dense Block'''

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class RRDBNet(nn.Module):
    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
        super(RRDBNet, self).__init__()
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
        self.sf = sf

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if self.sf == 4:
            self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk

        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        if self.sf == 4:
            fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        return out

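Quick sanity check for the generator above (a sketch assuming the RRDBNet class from the deleted file is in scope; nb=2 is used only to keep the toy network small): an sf=4 model upsamples by two nearest-neighbour 2x stages, so a 16x16 input should come out 64x64.

import torch

net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=2, gc=32, sf=4)
net.eval()
with torch.no_grad():
    y = net(torch.zeros(1, 3, 16, 16))
print(y.shape)  # torch.Size([1, 3, 64, 64])
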
@@ -1,80 +1,463 @@
Removed:

# this file is taken from https://github.com/xinntao/ESRGAN

import functools
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_layer(block, n_layers):
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    return nn.Sequential(*layers)


class ResidualDenseBlock_5C(nn.Module):
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(nn.Module):
    '''Residual in Residual Dense Block'''

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class RRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super(RRDBNet, self).__init__()
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk

        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        return out

Added:

# this file is adapted from https://github.com/victorca25/iNNfer

import math
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F


####################
# RRDBNet Generator
####################

class RRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
                 act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
                 finalact=None, gaussian_noise=False, plus=False):
        super(RRDBNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        self.resrgan_scale = 0
        if in_nc % 16 == 0:
            self.resrgan_scale = 1
        elif in_nc != 4 and in_nc % 4 == 0:
            self.resrgan_scale = 2

        fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
        rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
                          norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
                          gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
        LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)

        if upsample_mode == 'upconv':
            upsample_block = upconv_block
        elif upsample_mode == 'pixelshuffle':
            upsample_block = pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
        HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
        HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)

        outact = act(finalact) if finalact else None

        self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
                                *upsampler, HR_conv0, HR_conv1, outact)

    def forward(self, x, outm=None):
        if self.resrgan_scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        elif self.resrgan_scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        else:
            feat = x

        return self.model(feat)


class RRDB(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    """

    def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
                 spectral_norm=False, gaussian_noise=False, plus=False):
        super(RRDB, self).__init__()
        # This is for backwards compatibility with existing models
        if nr == 3:
            self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
                                              norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
                                              gaussian_noise=gaussian_noise, plus=plus)
            self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
                                              norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
                                              gaussian_noise=gaussian_noise, plus=plus)
            self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
                                              norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
                                              gaussian_noise=gaussian_noise, plus=plus)
        else:
            RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
                                              norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
                                              gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
            self.RDBs = nn.Sequential(*RDB_list)

    def forward(self, x):
        if hasattr(self, 'RDB1'):
            out = self.RDB1(x)
            out = self.RDB2(out)
            out = self.RDB3(out)
        else:
            out = self.RDBs(x)
        return out * 0.2 + x


class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Modified options that can be used:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
            {Rakotonirina} and A. {Rasoanaivo}
    """

    def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
                 spectral_norm=False, gaussian_noise=False, plus=False):
        super(ResidualDenseBlock_5C, self).__init__()

        self.noise = GaussianNoise() if gaussian_noise else None
        self.conv1x1 = conv1x1(nf, gc) if plus else None
        self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)
        self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)
        self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)
        self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        if self.conv1x1:
            x2 = x2 + self.conv1x1(x)
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        if self.conv1x1:
            x4 = x4 + x2
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        if self.noise:
            return self.noise(x5.mul(0.2) + x)
        else:
            return x5 * 0.2 + x


####################
# ESRGANplus
####################

class GaussianNoise(nn.Module):
    def __init__(self, sigma=0.1, is_relative_detach=False):
        super().__init__()
        self.sigma = sigma
        self.is_relative_detach = is_relative_detach
        self.noise = torch.tensor(0, dtype=torch.float)

    def forward(self, x):
        if self.training and self.sigma != 0:
            self.noise = self.noise.to(x.device)
            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
            sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
            x = x + sampled_noise
        return x

def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


####################
# SRVGGNetCompact
####################

class SRVGGNetCompact(nn.Module):
    """A compact VGG-style network structure for super-resolution.
    This class is copied from https://github.com/xinntao/Real-ESRGAN
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
        super(SRVGGNetCompact, self).__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        self.body = nn.ModuleList()
        # the first conv
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        # the first activation
        if act_type == 'relu':
            activation = nn.ReLU(inplace=True)
        elif act_type == 'prelu':
            activation = nn.PReLU(num_parameters=num_feat)
        elif act_type == 'leakyrelu':
            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.body.append(activation)

        # the body structure
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            # activation
            if act_type == 'relu':
                activation = nn.ReLU(inplace=True)
            elif act_type == 'prelu':
                activation = nn.PReLU(num_parameters=num_feat)
            elif act_type == 'leakyrelu':
                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
            self.body.append(activation)

        # the last conv
        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
        # upsample
        self.upsampler = nn.PixelShuffle(upscale)

    def forward(self, x):
        out = x
        for i in range(0, len(self.body)):
            out = self.body[i](out)

        out = self.upsampler(out)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
        out += base
        return out


####################
# Upsampler
####################

class Upsample(nn.Module):
    r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
    The input data is assumed to be of the form
    `minibatch x channels x [optional depth] x [optional height] x width`.
    """

    def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
        super(Upsample, self).__init__()
        if isinstance(scale_factor, tuple):
            self.scale_factor = tuple(float(factor) for factor in scale_factor)
        else:
            self.scale_factor = float(scale_factor) if scale_factor else None
        self.mode = mode
        self.size = size
        self.align_corners = align_corners

    def forward(self, x):
        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)

    def extra_repr(self):
        if self.scale_factor is not None:
            info = 'scale_factor=' + str(self.scale_factor)
        else:
            info = 'size=' + str(self.size)
        info += ', mode=' + self.mode
        return info


def pixel_unshuffle(x, scale):
    """ Pixel unshuffle.
    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.
    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
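
# (illustration, not part of the file) pixel_unshuffle above is the inverse of
# nn.PixelShuffle: it folds each scale x scale spatial patch into channels, so
# PixelShuffle reconstructs the original tensor exactly:
#     x = torch.arange(16.).view(1, 1, 4, 4)
#     y = pixel_unshuffle(x, scale=2)          # shape (1, 4, 2, 2)
#     assert torch.equal(nn.PixelShuffle(2)(y), x)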
|
|
||||||
|
|
||||||
|
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||||
|
pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
|
||||||
|
"""
|
||||||
|
Pixel shuffle layer
|
||||||
|
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
|
||||||
|
Neural Network, CVPR17)
|
||||||
|
"""
|
||||||
|
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
|
||||||
|
pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
|
||||||
|
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
||||||
|
|
||||||
|
n = norm(norm_type, out_nc) if norm_type else None
|
||||||
|
a = act(act_type) if act_type else None
|
||||||
|
return sequential(conv, pixel_shuffle, n, a)
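

# Shape walk-through of the sub-pixel idea above (an illustrative sketch; the
# helper is hypothetical): the conv emits out_nc * upscale_factor**2 channels,
# which nn.PixelShuffle rearranges into an out_nc image at upscale_factor times
# the input resolution.
def _pixelshuffle_block_shape_example():
    import torch
    block = pixelshuffle_block(in_nc=64, out_nc=3, upscale_factor=2)
    assert block(torch.zeros(1, 64, 16, 16)).shape == (1, 3, 32, 32)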


def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
                 pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
    """ Upconv layer """
    upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
    upsample = Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
                      pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
    return sequential(upsample, conv)


####################
# Basic blocks
####################


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.module): nn.module class for basic block. (block)
        num_basic_block (int): number of blocks. (n_layers)

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)
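

# Usage sketch for make_layer (hypothetical helper, for illustration): any
# nn.Module class works as basic_block, with **kwarg forwarded to each instance.
def _make_layer_example():
    body = make_layer(nn.Conv2d, 4, in_channels=64, out_channels=64, kernel_size=3, padding=1)
    assert len(body) == 4
    return body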


def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
    """ activation helper """
    act_type = act_type.lower()
    if act_type == 'relu':
        layer = nn.ReLU(inplace)
    elif act_type in ('leakyrelu', 'lrelu'):
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    elif act_type == 'tanh':  # [-1, 1] range output
        layer = nn.Tanh()
    elif act_type == 'sigmoid':  # [0, 1] range output
        layer = nn.Sigmoid()
    else:
        raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
    return layer


class Identity(nn.Module):
    def __init__(self, *args):
        super(Identity, self).__init__()

    def forward(self, x, *args):
        return x


def norm(norm_type, nc):
    """ Return a normalization layer """
    norm_type = norm_type.lower()
    if norm_type == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm_type == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
    elif norm_type == 'none':
        layer = Identity()
    else:
        raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
    return layer


def pad(pad_type, padding):
    """ padding layer helper """
    pad_type = pad_type.lower()
    if padding == 0:
        return None
    if pad_type == 'reflect':
        layer = nn.ReflectionPad2d(padding)
    elif pad_type == 'replicate':
        layer = nn.ReplicationPad2d(padding)
    elif pad_type == 'zero':
        layer = nn.ZeroPad2d(padding)
    else:
        raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
    return layer


def get_valid_padding(kernel_size, dilation):
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding


class ShortcutBlock(nn.Module):
    """ Elementwise sum the output of a submodule to its input """
    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = x + self.sub(x)
        return output

    def __repr__(self):
        return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')


def sequential(*args):
    """ Flatten Sequential. It unwraps nn.Sequential. """
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)


def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
               pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
               spectral_norm=False):
    """ Conv layer with padding, normalization, activation """
    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
    padding = padding if pad_type == 'zero' else 0

    if convtype == 'PartialConv2D':
        c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                          dilation=dilation, bias=bias, groups=groups)
    elif convtype == 'DeformConv2D':
        c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                         dilation=dilation, bias=bias, groups=groups)
    elif convtype == 'Conv3D':
        c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                      dilation=dilation, bias=bias, groups=groups)
    else:
        c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                      dilation=dilation, bias=bias, groups=groups)

    if spectral_norm:
        c = nn.utils.spectral_norm(c)

    a = act(act_type) if act_type else None
    if 'CNA' in mode:
        n = norm(norm_type, out_nc) if norm_type else None
        return sequential(p, c, n, a)
    elif mode == 'NAC':
        if norm_type is None and act_type is not None:
            a = act(act_type, inplace=False)
        n = norm(norm_type, in_nc) if norm_type else None
        return sequential(n, a, p, c)
@ -0,0 +1,83 @@
import os
import sys
import traceback

import git

from modules import paths, shared


extensions = []
extensions_dir = os.path.join(paths.script_path, "extensions")


def active():
    return [x for x in extensions if x.enabled]


class Extension:
    def __init__(self, name, path, enabled=True):
        self.name = name
        self.path = path
        self.enabled = enabled
        self.status = ''
        self.can_update = False

        repo = None
        try:
            if os.path.exists(os.path.join(path, ".git")):
                repo = git.Repo(path)
        except Exception:
            print(f"Error reading github repository info from {path}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

        if repo is None or repo.bare:
            self.remote = None
        else:
            self.remote = next(repo.remote().urls, None)
            self.status = 'unknown'

    def list_files(self, subdir, extension):
        from modules import scripts

        dirpath = os.path.join(self.path, subdir)
        if not os.path.isdir(dirpath):
            return []

        res = []
        for filename in sorted(os.listdir(dirpath)):
            res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))

        res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]

        return res

    def check_updates(self):
        repo = git.Repo(self.path)
        for fetch in repo.remote().fetch("--dry-run"):
            if fetch.flags != fetch.HEAD_UPTODATE:
                self.can_update = True
                self.status = "behind"
                return

        self.can_update = False
        self.status = "latest"

    def pull(self):
        repo = git.Repo(self.path)
        repo.remotes.origin.pull()


def list_extensions():
    extensions.clear()

    if not os.path.isdir(extensions_dir):
        return

    for dirname in sorted(os.listdir(extensions_dir)):
        path = os.path.join(extensions_dir, dirname)
        if not os.path.isdir(path):
            continue

        extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
        extensions.append(extension)
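

# How the registry is meant to be consumed (an illustrative sketch; this helper
# is hypothetical): refresh the list from disk, then walk the enabled entries.
def _list_active_extension_scripts():
    list_extensions()
    return [(ext.name, ext.list_files("scripts", ".py")) for ext in active()]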
@ -1,183 +0,0 @@
import os
import shutil
import sys


def traverse_all_files(output_dir, image_list, curr_dir=None):
    curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
    try:
        f_list = os.listdir(curr_path)
    except Exception:
        if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
            image_list.append(curr_dir)
        return image_list
    for file in f_list:
        file = file if curr_dir is None else os.path.join(curr_dir, file)
        file_path = os.path.join(curr_path, file)
        if file[-4:] == ".txt":
            pass
        elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
            image_list.append(file)
        else:
            image_list = traverse_all_files(output_dir, image_list, file)
    return image_list


def get_recent_images(dir_name, page_index, step, image_index, tabname):
    page_index = int(page_index)
    image_list = []
    if not os.path.exists(dir_name):
        pass
    elif os.path.isdir(dir_name):
        image_list = traverse_all_files(dir_name, image_list)
        image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
    else:
        print(f'ERROR: "{dir_name}" is not a directory. Check the path in the settings.', file=sys.stderr)
    num = 48 if tabname != "extras" else 12
    max_page_index = len(image_list) // num + 1
    page_index = max_page_index if page_index == -1 else page_index + step
    page_index = 1 if page_index < 1 else page_index
    page_index = max_page_index if page_index > max_page_index else page_index
    idx_frm = (page_index - 1) * num
    image_list = image_list[idx_frm:idx_frm + num]
    image_index = int(image_index)
    if image_index < 0 or image_index > len(image_list) - 1:
        current_file = None
        hidden = None
    else:
        current_file = image_list[int(image_index)]
        hidden = os.path.join(dir_name, current_file)
    return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""


def first_page_click(dir_name, page_index, image_index, tabname):
    return get_recent_images(dir_name, 1, 0, image_index, tabname)


def end_page_click(dir_name, page_index, image_index, tabname):
    return get_recent_images(dir_name, -1, 0, image_index, tabname)


def prev_page_click(dir_name, page_index, image_index, tabname):
    return get_recent_images(dir_name, page_index, -1, image_index, tabname)


def next_page_click(dir_name, page_index, image_index, tabname):
    return get_recent_images(dir_name, page_index, 1, image_index, tabname)


def page_index_change(dir_name, page_index, image_index, tabname):
    return get_recent_images(dir_name, page_index, 0, image_index, tabname)


def show_image_info(num, image_path, filenames):
    # print(f"select image {num}")
    file = filenames[int(num)]
    return file, num, os.path.join(image_path, file)


def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
    if name == "":
        return filenames, delete_num
    else:
        delete_num = int(delete_num)
        index = list(filenames).index(name)
        i = 0
        new_file_list = []
        for name in filenames:
            if i >= index and i < index + delete_num:
                path = os.path.join(dir_name, name)
                if os.path.exists(path):
                    print(f"Delete file {path}")
                    os.remove(path)
                    txt_file = os.path.splitext(path)[0] + ".txt"
                    if os.path.exists(txt_file):
                        os.remove(txt_file)
                else:
                    print(f"File does not exist: {path}")
            else:
                new_file_list.append(name)
            i += 1
        return new_file_list, 1


def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
    if opts.outdir_samples != "":
        dir_name = opts.outdir_samples
    elif tabname == "txt2img":
        dir_name = opts.outdir_txt2img_samples
    elif tabname == "img2img":
        dir_name = opts.outdir_img2img_samples
    elif tabname == "extras":
        dir_name = opts.outdir_extras_samples
    else:
        return
    with gr.Row():
        renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
        first_page = gr.Button('First Page')
        prev_page = gr.Button('Prev Page')
        page_index = gr.Number(value=1, label="Page Index")
        next_page = gr.Button('Next Page')
        end_page = gr.Button('End Page')
    with gr.Row(elem_id=tabname + "_images_history"):
        with gr.Row():
            with gr.Column(scale=2):
                history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
                with gr.Row():
                    delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
                    delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
            with gr.Column():
                with gr.Row():
                    pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
                    pnginfo_send_to_img2img = gr.Button('Send to img2img')
                with gr.Row():
                    with gr.Column():
                        img_file_info = gr.Textbox(label="Generate Info", interactive=False)
                        img_file_name = gr.Textbox(label="File Name", interactive=False)
                with gr.Row():
                    # hidden items
                    img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
                    tabname_box = gr.Textbox(tabname, visible=False)
                    image_index = gr.Textbox(value=-1, visible=False)
                    set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
                    filenames = gr.State()
                    hidden = gr.Image(type="pil", visible=False)
                    info1 = gr.Textbox(visible=False)
                    info2 = gr.Textbox(visible=False)

    # turn pages
    gallery_inputs = [img_path, page_index, image_index, tabname_box]
    gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]

    first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
    next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
    prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
    end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
    page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
    renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
    # page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])

    # other functions
    set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
    img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
    delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
    hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])

    # pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
    switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
    switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')


def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
    with gr.Blocks(analytics_enabled=False) as images_history:
        with gr.Tabs() as tabs:
            with gr.Tab("txt2img history"):
                with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
                    show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
            with gr.Tab("img2img history"):
                with gr.Blocks(analytics_enabled=False) as images_history_img2img:
                    show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
            with gr.Tab("extras history"):
                with gr.Blocks(analytics_enabled=False) as images_history_img2img:
                    show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
    return images_history
@ -0,0 +1,132 @@
import sys
import traceback
from collections import namedtuple
import inspect


def report_exception(c, job):
    print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
    print(traceback.format_exc(), file=sys.stderr)


class ImageSaveParams:
    def __init__(self, image, p, filename, pnginfo):
        self.image = image
        """the PIL image itself"""

        self.p = p
        """p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""

        self.filename = filename
        """name of file that the image would be saved to"""

        self.pnginfo = pnginfo
        """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""


ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
callbacks_before_image_saved = []
callbacks_image_saved = []


def clear_callbacks():
    callbacks_model_loaded.clear()
    callbacks_ui_tabs.clear()
    callbacks_ui_settings.clear()
    callbacks_before_image_saved.clear()
    callbacks_image_saved.clear()


def model_loaded_callback(sd_model):
    for c in callbacks_model_loaded:
        try:
            c.callback(sd_model)
        except Exception:
            report_exception(c, 'model_loaded_callback')


def ui_tabs_callback():
    res = []

    for c in callbacks_ui_tabs:
        try:
            res += c.callback() or []
        except Exception:
            report_exception(c, 'ui_tabs_callback')

    return res


def ui_settings_callback():
    for c in callbacks_ui_settings:
        try:
            c.callback()
        except Exception:
            report_exception(c, 'ui_settings_callback')


def before_image_saved_callback(params: ImageSaveParams):
    for c in callbacks_before_image_saved:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'before_image_saved_callback')


def image_saved_callback(params: ImageSaveParams):
    for c in callbacks_image_saved:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'image_saved_callback')


def add_callback(callbacks, fun):
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if len(stack) > 0 else 'unknown file'

    callbacks.append(ScriptCallback(filename, fun))


def on_model_loaded(callback):
    """register a function to be called when the stable diffusion model is created; the model is
    passed as an argument"""
    add_callback(callbacks_model_loaded, callback)


def on_ui_tabs(callback):
    """register a function to be called when the UI is creating new tabs.
    The function must either return None, which means no new tabs are to be added, or a list, where
    each element is a tuple:
        (gradio_component, title, elem_id)

    gradio_component is a gradio component to be used for the contents of the tab (usually gr.Blocks)
    title is the tab text displayed to the user in the UI
    elem_id is the HTML id for the tab
    """
    add_callback(callbacks_ui_tabs, callback)


def on_ui_settings(callback):
    """register a function to be called before UI settings are populated; add your settings
    by using shared.opts.add_option(shared.OptionInfo(...)) """
    add_callback(callbacks_ui_settings, callback)


def on_before_image_saved(callback):
    """register a function to be called before an image is saved to a file.
    The callback is called with one argument:
        - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
    """
    add_callback(callbacks_before_image_saved, callback)


def on_image_saved(callback):
    """register a function to be called after an image is saved to a file.
    The callback is called with one argument:
        - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
    """
    add_callback(callbacks_image_saved, callback)
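

# Minimal registration sketch (the callbacks here are hypothetical): extension
# scripts import this module and hook in through the on_* functions; the webui
# then invokes each registered callback when the matching event fires.
def _example_register_callbacks():
    def handle_model(sd_model):
        print("model loaded:", type(sd_model).__name__)

    def stamp_png(params: ImageSaveParams):
        # fields may be changed here, before the file is written
        params.pnginfo["example-note"] = "added in on_before_image_saved"

    on_model_loaded(handle_model)
    on_before_image_saved(stamp_png)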
@ -0,0 +1,341 @@
import cv2
import requests
import os
from collections import defaultdict
from math import log, sqrt
import numpy as np
from PIL import Image, ImageDraw

GREEN = "#0F0"
BLUE = "#00F"
RED = "#F00"


def crop_image(im, settings):
    """ Intelligently crop an image to the subject matter """

    scale_by = 1
    if is_landscape(im.width, im.height):
        scale_by = settings.crop_height / im.height
    elif is_portrait(im.width, im.height):
        scale_by = settings.crop_width / im.width
    elif is_square(im.width, im.height):
        if is_square(settings.crop_width, settings.crop_height):
            scale_by = settings.crop_width / im.width
        elif is_landscape(settings.crop_width, settings.crop_height):
            scale_by = settings.crop_width / im.width
        elif is_portrait(settings.crop_width, settings.crop_height):
            scale_by = settings.crop_height / im.height

    im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
    im_debug = im.copy()

    focus = focal_point(im_debug, settings)

    # take the focal point and turn it into crop coordinates that try to center over the focal
    # point but then get adjusted back into the frame
    y_half = int(settings.crop_height / 2)
    x_half = int(settings.crop_width / 2)

    x1 = focus.x - x_half
    if x1 < 0:
        x1 = 0
    elif x1 + settings.crop_width > im.width:
        x1 = im.width - settings.crop_width

    y1 = focus.y - y_half
    if y1 < 0:
        y1 = 0
    elif y1 + settings.crop_height > im.height:
        y1 = im.height - settings.crop_height

    x2 = x1 + settings.crop_width
    y2 = y1 + settings.crop_height

    crop = [x1, y1, x2, y2]

    results = []

    results.append(im.crop(tuple(crop)))

    if settings.annotate_image:
        d = ImageDraw.Draw(im_debug)
        rect = list(crop)
        rect[2] -= 1
        rect[3] -= 1
        d.rectangle(rect, outline=GREEN)
        results.append(im_debug)
        if settings.desktop_view_image:
            im_debug.show()

    return results


def focal_point(im, settings):
    corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
    entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
    face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []

    pois = []

    weight_pref_total = 0
    if len(corner_points) > 0:
        weight_pref_total += settings.corner_points_weight
    if len(entropy_points) > 0:
        weight_pref_total += settings.entropy_points_weight
    if len(face_points) > 0:
        weight_pref_total += settings.face_points_weight

    corner_centroid = None
    if len(corner_points) > 0:
        corner_centroid = centroid(corner_points)
        corner_centroid.weight = settings.corner_points_weight / weight_pref_total
        pois.append(corner_centroid)

    entropy_centroid = None
    if len(entropy_points) > 0:
        entropy_centroid = centroid(entropy_points)
        entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
        pois.append(entropy_centroid)

    face_centroid = None
    if len(face_points) > 0:
        face_centroid = centroid(face_points)
        face_centroid.weight = settings.face_points_weight / weight_pref_total
        pois.append(face_centroid)

    average_point = poi_average(pois, settings)

    if settings.annotate_image:
        d = ImageDraw.Draw(im)
        max_size = min(im.width, im.height) * 0.07
        if corner_centroid is not None:
            color = BLUE
            box = corner_centroid.bounding(max_size * corner_centroid.weight)
            d.text((box[0], box[1] - 15), "Edge: %.02f" % corner_centroid.weight, fill=color)
            d.ellipse(box, outline=color)
            if len(corner_points) > 1:
                for f in corner_points:
                    d.rectangle(f.bounding(4), outline=color)
        if entropy_centroid is not None:
            color = "#ff0"
            box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
            d.text((box[0], box[1] - 15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
            d.ellipse(box, outline=color)
            if len(entropy_points) > 1:
                for f in entropy_points:
                    d.rectangle(f.bounding(4), outline=color)
        if face_centroid is not None:
            color = RED
            box = face_centroid.bounding(max_size * face_centroid.weight)
            d.text((box[0], box[1] - 15), "Face: %.02f" % face_centroid.weight, fill=color)
            d.ellipse(box, outline=color)
            if len(face_points) > 1:
                for f in face_points:
                    d.rectangle(f.bounding(4), outline=color)

        d.ellipse(average_point.bounding(max_size), outline=GREEN)

    return average_point


def image_face_points(im, settings):
    if settings.dnn_model_path is not None:
        detector = cv2.FaceDetectorYN.create(
            settings.dnn_model_path,
            "",
            (im.width, im.height),
            0.9,  # score threshold
            0.3,  # nms threshold
            5000  # keep top k before nms
        )
        faces = detector.detect(np.array(im))
        results = []
        if faces[1] is not None:
            for face in faces[1]:
                x = face[0]
                y = face[1]
                w = face[2]
                h = face[3]
                results.append(
                    PointOfInterest(
                        int(x + (w * 0.5)),  # face focus left/right is center
                        int(y + (h * 0.33)),  # face focus up/down is close to the top of the head
                        size=w,
                        weight=1 / len(faces[1])
                    )
                )
        return results
    else:
        np_im = np.array(im)
        gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)

        tries = [
            [f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01],
            [f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05],
            [f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05],
            [f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05],
            [f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05],
            [f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05],
            [f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05],
            [f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05]
        ]
        for t in tries:
            classifier = cv2.CascadeClassifier(t[0])
            minsize = int(min(im.width, im.height) * t[1])  # at least N percent of the smallest side
            try:
                faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
                                                    minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
            except Exception:
                continue

            if len(faces) > 0:
                rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
                return [PointOfInterest((r[0] + r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0] - r[2]), weight=1 / len(rects)) for r in rects]
        return []


def image_corner_points(im, settings):
    grayscale = im.convert("L")

    # naive attempt at preventing focal points from collecting at watermarks near the bottom
    gd = ImageDraw.Draw(grayscale)
    gd.rectangle([0, im.height * .9, im.width, im.height], fill="#999")

    np_im = np.array(grayscale)

    points = cv2.goodFeaturesToTrack(
        np_im,
        maxCorners=100,
        qualityLevel=0.04,
        minDistance=min(grayscale.width, grayscale.height) * 0.06,
        useHarrisDetector=False,
    )

    if points is None:
        return []

    focal_points = []
    for point in points:
        x, y = point.ravel()
        focal_points.append(PointOfInterest(x, y, size=4, weight=1 / len(points)))

    return focal_points


def image_entropy_points(im, settings):
    landscape = im.height < im.width
    portrait = im.height > im.width
    if landscape:
        move_idx = [0, 2]
        move_max = im.size[0]
    elif portrait:
        move_idx = [1, 3]
        move_max = im.size[1]
    else:
        return []

    e_max = 0
    crop_current = [0, 0, settings.crop_width, settings.crop_height]
    crop_best = crop_current
    while crop_current[move_idx[1]] < move_max:
        crop = im.crop(tuple(crop_current))
        e = image_entropy(crop)

        if e > e_max:
            e_max = e
            crop_best = list(crop_current)

        crop_current[move_idx[0]] += 4
        crop_current[move_idx[1]] += 4

    x_mid = int(crop_best[0] + settings.crop_width / 2)
    y_mid = int(crop_best[1] + settings.crop_height / 2)

    return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]


def image_entropy(im):
    # greyscale image entropy
    # band = np.asarray(im.convert("L"))
    band = np.asarray(im.convert("1"), dtype=np.uint8)
    hist, _ = np.histogram(band, bins=range(0, 256))
    hist = hist[hist > 0]
    return -np.log2(hist / hist.sum()).sum()


def centroid(pois):
    x = [poi.x for poi in pois]
    y = [poi.y for poi in pois]
    return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))


def poi_average(pois, settings):
    weight = 0.0
    x = 0.0
    y = 0.0
    for poi in pois:
        weight += poi.weight
        x += poi.x * poi.weight
        y += poi.y * poi.weight
    avg_x = round(x / weight)
    avg_y = round(y / weight)

    return PointOfInterest(avg_x, avg_y)


def is_landscape(w, h):
    return w > h


def is_portrait(w, h):
    return h > w


def is_square(w, h):
    return w == h


def download_and_cache_models(dirname):
    download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
    model_file_name = 'face_detection_yunet.onnx'

    if not os.path.exists(dirname):
        os.makedirs(dirname)

    cache_file = os.path.join(dirname, model_file_name)
    if not os.path.exists(cache_file):
        print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
        response = requests.get(download_url)
        with open(cache_file, "wb") as f:
            f.write(response.content)

    if os.path.exists(cache_file):
        return cache_file
    return None


class PointOfInterest:
    def __init__(self, x, y, weight=1.0, size=10):
        self.x = x
        self.y = y
        self.weight = weight
        self.size = size

    def bounding(self, size):
        return [
            self.x - size // 2,
            self.y - size // 2,
            self.x + size // 2,
            self.y + size // 2
        ]


class Settings:
    def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
        self.crop_width = crop_width
        self.crop_height = crop_height
        self.corner_points_weight = corner_points_weight
        self.entropy_points_weight = entropy_points_weight
        self.face_points_weight = face_points_weight
        self.annotate_image = annotate_image
        self.desktop_view_image = False
        self.dnn_model_path = dnn_model_path
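

# End-to-end usage sketch (hypothetical values; pil_image stands in for any RGB
# PIL image, and "models/opencv" is an assumed cache directory): weight face,
# corner and entropy detections into one focal point, then crop a 512x512
# window centered on it.
def _autocrop_example(pil_image):
    settings = Settings(crop_width=512, crop_height=512, face_points_weight=0.9,
                        dnn_model_path=download_and_cache_models("models/opencv"))
    return crop_image(pil_image, settings)[0]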
@ -0,0 +1,172 @@
import json
import os.path
import shutil
import sys
import time
import traceback

import git

import gradio as gr
import html

from modules import extensions, shared, paths


def check_access():
    assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"


def apply_and_restart(disable_list, update_list):
    check_access()

    disabled = json.loads(disable_list)
    assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"

    update = json.loads(update_list)
    assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"

    update = set(update)

    for ext in extensions.extensions:
        if ext.name not in update:
            continue

        try:
            ext.pull()
        except Exception:
            print(f"Error pulling updates for {ext.name}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    shared.opts.disabled_extensions = disabled
    shared.opts.save(shared.config_filename)

    shared.state.interrupt()
    shared.state.need_restart = True


def check_updates():
    check_access()

    for ext in extensions.extensions:
        if ext.remote is None:
            continue

        try:
            ext.check_updates()
        except Exception:
            print(f"Error checking updates for {ext.name}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    return extension_table()


def extension_table():
    code = f"""<!-- {time.time()} -->
<table id="extensions">
    <thead>
        <tr>
            <th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
            <th>URL</th>
            <th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
        </tr>
    </thead>
    <tbody>
"""

    for ext in extensions.extensions:
        if ext.can_update:
            ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
        else:
            ext_status = ext.status

        code += f"""
        <tr>
            <td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
            <td><a href="{html.escape(ext.remote or '')}">{html.escape(ext.remote or '')}</a></td>
            <td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
        </tr>
"""

    code += """
    </tbody>
</table>
"""

    return code


def install_extension_from_url(dirname, url):
    check_access()

    assert url, 'No URL specified'

    if dirname is None or dirname == "":
        *parts, last_part = url.split('/')
        last_part = last_part.replace(".git", "")

        dirname = last_part

    target_dir = os.path.join(extensions.extensions_dir, dirname)
    assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'

    assert len([x for x in extensions.extensions if x.remote == url]) == 0, 'Extension with this URL is already installed'

    tmpdir = os.path.join(paths.script_path, "tmp", dirname)

    try:
        shutil.rmtree(tmpdir, True)

        repo = git.Repo.clone_from(url, tmpdir)
        repo.remote().fetch()

        os.rename(tmpdir, target_dir)

        extensions.list_extensions()
        return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
    finally:
        shutil.rmtree(tmpdir, True)


def create_ui():
    import modules.ui

    with gr.Blocks(analytics_enabled=False) as ui:
        with gr.Tabs(elem_id="tabs_extensions") as tabs:
            with gr.TabItem("Installed"):
                extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False)
                extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False)

                with gr.Row():
                    apply = gr.Button(value="Apply and restart UI", variant="primary")
                    check = gr.Button(value="Check for updates")

                extensions_table = gr.HTML(lambda: extension_table())

                apply.click(
                    fn=apply_and_restart,
                    _js="extensions_apply",
                    inputs=[extensions_disabled_list, extensions_update_list],
                    outputs=[],
                )

                check.click(
                    fn=check_updates,
                    _js="extensions_check",
                    inputs=[],
                    outputs=[extensions_table],
                )

            with gr.TabItem("Install from URL"):
                install_url = gr.Text(label="URL for extension's git repository")
                install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
                install_button = gr.Button(value="Install", variant="primary")
                install_result = gr.HTML(elem_id="extension_install_result")

                install_button.click(
                    fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
                    inputs=[install_dirname, install_url],
                    outputs=[extensions_table, install_result],
                )

    return ui
@ -0,0 +1,29 @@
import unittest


class TestExtrasWorking(unittest.TestCase):
    def setUp(self):
        self.url_extras = "http://localhost:7860/sdapi/v1/extra-single-image"
        self.simple_extras = {
            "resize_mode": 0,
            "show_extras_results": True,
            "gfpgan_visibility": 0,
            "codeformer_visibility": 0,
            "codeformer_weight": 0,
            "upscaling_resize": 2,
            "upscaling_resize_w": 512,
            "upscaling_resize_h": 512,
            "upscaling_crop": True,
            "upscaler_1": "None",
            "upscaler_2": "None",
            "extras_upscaler_2_visibility": 0,
            "image": ""
        }


class TestExtrasCorrectness(unittest.TestCase):
    pass


if __name__ == "__main__":
    unittest.main()
@ -0,0 +1,59 @@
import unittest
import requests
from gradio.processing_utils import encode_pil_to_base64
from PIL import Image


class TestImg2ImgWorking(unittest.TestCase):
    def setUp(self):
        self.url_img2img = "http://localhost:7860/sdapi/v1/img2img"
        self.simple_img2img = {
            "init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))],
            "resize_mode": 0,
            "denoising_strength": 0.75,
            "mask": None,
            "mask_blur": 4,
            "inpainting_fill": 0,
            "inpaint_full_res": False,
            "inpaint_full_res_padding": 0,
            "inpainting_mask_invert": 0,
            "prompt": "example prompt",
            "styles": [],
            "seed": -1,
            "subseed": -1,
            "subseed_strength": 0,
            "seed_resize_from_h": -1,
            "seed_resize_from_w": -1,
            "batch_size": 1,
            "n_iter": 1,
            "steps": 3,
            "cfg_scale": 7,
            "width": 64,
            "height": 64,
            "restore_faces": False,
            "tiling": False,
            "negative_prompt": "",
            "eta": 0,
            "s_churn": 0,
            "s_tmax": 0,
            "s_tmin": 0,
            "s_noise": 1,
            "override_settings": {},
            "sampler_index": "Euler a",
            "include_init_images": False
        }

    def test_img2img_simple_performed(self):
        self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)

    def test_inpainting_masked_performed(self):
        self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
        self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)


class TestImg2ImgCorrectness(unittest.TestCase):
    pass


if __name__ == "__main__":
    unittest.main()
@ -0,0 +1,19 @@
import unittest
import requests
import time


def run_tests():
    timeout_threshold = 240
    start_time = time.time()
    while time.time() - start_time < timeout_threshold:
        try:
            requests.head("http://localhost:7860/")
            break
        except requests.exceptions.ConnectionError:
            pass
    if time.time() - start_time < timeout_threshold:
        suite = unittest.TestLoader().discover('', pattern='*_test.py')
        result = unittest.TextTestRunner(verbosity=2).run(suite)
    else:
        print("Launch unsuccessful")
Binary file not shown. (After: 9.7 KiB)
Binary file not shown. (After: 362 B)
@ -0,0 +1,74 @@
import unittest
import requests


class TestTxt2ImgWorking(unittest.TestCase):
    def setUp(self):
        self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img"
        self.simple_txt2img = {
            "enable_hr": False,
            "denoising_strength": 0,
            "firstphase_width": 0,
            "firstphase_height": 0,
            "prompt": "example prompt",
            "styles": [],
            "seed": -1,
            "subseed": -1,
            "subseed_strength": 0,
            "seed_resize_from_h": -1,
            "seed_resize_from_w": -1,
            "batch_size": 1,
            "n_iter": 1,
            "steps": 3,
            "cfg_scale": 7,
            "width": 64,
            "height": 64,
            "restore_faces": False,
            "tiling": False,
            "negative_prompt": "",
            "eta": 0,
            "s_churn": 0,
            "s_tmax": 0,
            "s_tmin": 0,
            "s_noise": 1,
            "sampler_index": "Euler a"
        }

    def test_txt2img_simple_performed(self):
        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)

    def test_txt2img_with_negative_prompt_performed(self):
        self.simple_txt2img["negative_prompt"] = "example negative prompt"
        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)

    def test_txt2img_not_square_image_performed(self):
        self.simple_txt2img["height"] = 128
        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)

    def test_txt2img_with_hrfix_performed(self):
        self.simple_txt2img["enable_hr"] = True
        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)

    def test_txt2img_with_restore_faces_performed(self):
        self.simple_txt2img["restore_faces"] = True
        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)

    def test_txt2img_with_tiling_performed(self):
        self.simple_txt2img["tiling"] = True
        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)

    def test_txt2img_with_vanilla_sampler_performed(self):
        self.simple_txt2img["sampler_index"] = "PLMS"
        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)

    def test_txt2img_multiple_batches_performed(self):
        self.simple_txt2img["n_iter"] = 2
        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)


class TestTxt2ImgCorrectness(unittest.TestCase):
    pass


if __name__ == "__main__":
    unittest.main()