Merge branch 'AUTOMATIC1111:master' into master
commit
25de9df364
@ -1 +1,13 @@
|
||||
* @AUTOMATIC1111
|
||||
/localizations/ar_AR.json @xmodar @blackneoo
|
||||
/localizations/de_DE.json @LunixWasTaken
|
||||
/localizations/es_ES.json @innovaciones
|
||||
/localizations/fr_FR.json @tumbly
|
||||
/localizations/it_IT.json @EugenioBuffo
|
||||
/localizations/ja_JP.json @yuuki76
|
||||
/localizations/ko_KR.json @36DB
|
||||
/localizations/pt_BR.json @M-art-ucci
|
||||
/localizations/ru_RU.json @kabachuha
|
||||
/localizations/tr_TR.json @camenduru
|
||||
/localizations/zh_CN.json @dtlnor @bgluminous
|
||||
/localizations/zh_TW.json @benlisquare
|
||||
|
||||
@ -0,0 +1,24 @@
|
||||
|
||||
function extensions_apply(_, _){
|
||||
disable = []
|
||||
update = []
|
||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
||||
if(x.name.startsWith("enable_") && ! x.checked)
|
||||
disable.push(x.name.substr(7))
|
||||
|
||||
if(x.name.startsWith("update_") && x.checked)
|
||||
update.push(x.name.substr(7))
|
||||
})
|
||||
|
||||
restart_reload()
|
||||
|
||||
return [JSON.stringify(disable), JSON.stringify(update)]
|
||||
}
|
||||
|
||||
function extensions_check(){
|
||||
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
||||
x.innerHTML = "Loading..."
|
||||
})
|
||||
|
||||
return []
|
||||
}
|
||||
@ -1,206 +0,0 @@
|
||||
var images_history_click_image = function(){
|
||||
if (!this.classList.contains("transform")){
|
||||
var gallery = images_history_get_parent_by_class(this, "images_history_cantainor");
|
||||
var buttons = gallery.querySelectorAll(".gallery-item");
|
||||
var i = 0;
|
||||
var hidden_list = [];
|
||||
buttons.forEach(function(e){
|
||||
if (e.style.display == "none"){
|
||||
hidden_list.push(i);
|
||||
}
|
||||
i += 1;
|
||||
})
|
||||
if (hidden_list.length > 0){
|
||||
setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
|
||||
}
|
||||
}
|
||||
images_history_set_image_info(this);
|
||||
}
|
||||
|
||||
var images_history_click_tab = function(){
|
||||
var tabs_box = gradioApp().getElementById("images_history_tab");
|
||||
if (!tabs_box.classList.contains(this.getAttribute("tabname"))) {
|
||||
gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_renew_page").click();
|
||||
tabs_box.classList.add(this.getAttribute("tabname"))
|
||||
}
|
||||
}
|
||||
|
||||
function images_history_disabled_del(){
|
||||
gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
|
||||
btn.setAttribute('disabled','disabled');
|
||||
});
|
||||
}
|
||||
|
||||
function images_history_get_parent_by_class(item, class_name){
|
||||
var parent = item.parentElement;
|
||||
while(!parent.classList.contains(class_name)){
|
||||
parent = parent.parentElement;
|
||||
}
|
||||
return parent;
|
||||
}
|
||||
|
||||
function images_history_get_parent_by_tagname(item, tagname){
|
||||
var parent = item.parentElement;
|
||||
tagname = tagname.toUpperCase()
|
||||
while(parent.tagName != tagname){
|
||||
console.log(parent.tagName, tagname)
|
||||
parent = parent.parentElement;
|
||||
}
|
||||
return parent;
|
||||
}
|
||||
|
||||
function images_history_hide_buttons(hidden_list, gallery){
|
||||
var buttons = gallery.querySelectorAll(".gallery-item");
|
||||
var num = 0;
|
||||
buttons.forEach(function(e){
|
||||
if (e.style.display == "none"){
|
||||
num += 1;
|
||||
}
|
||||
});
|
||||
if (num == hidden_list.length){
|
||||
setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
|
||||
}
|
||||
for( i in hidden_list){
|
||||
buttons[hidden_list[i]].style.display = "none";
|
||||
}
|
||||
}
|
||||
|
||||
function images_history_set_image_info(button){
|
||||
var buttons = images_history_get_parent_by_tagname(button, "DIV").querySelectorAll(".gallery-item");
|
||||
var index = -1;
|
||||
var i = 0;
|
||||
buttons.forEach(function(e){
|
||||
if(e == button){
|
||||
index = i;
|
||||
}
|
||||
if(e.style.display != "none"){
|
||||
i += 1;
|
||||
}
|
||||
});
|
||||
var gallery = images_history_get_parent_by_class(button, "images_history_cantainor");
|
||||
var set_btn = gallery.querySelector(".images_history_set_index");
|
||||
var curr_idx = set_btn.getAttribute("img_index", index);
|
||||
if (curr_idx != index) {
|
||||
set_btn.setAttribute("img_index", index);
|
||||
images_history_disabled_del();
|
||||
}
|
||||
set_btn.click();
|
||||
|
||||
}
|
||||
|
||||
function images_history_get_current_img(tabname, image_path, files){
|
||||
return [
|
||||
gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"),
|
||||
image_path,
|
||||
files
|
||||
];
|
||||
}
|
||||
|
||||
function images_history_delete(del_num, tabname, img_path, img_file_name, page_index, filenames, image_index){
|
||||
image_index = parseInt(image_index);
|
||||
var tab = gradioApp().getElementById(tabname + '_images_history');
|
||||
var set_btn = tab.querySelector(".images_history_set_index");
|
||||
var buttons = [];
|
||||
tab.querySelectorAll(".gallery-item").forEach(function(e){
|
||||
if (e.style.display != 'none'){
|
||||
buttons.push(e);
|
||||
}
|
||||
});
|
||||
var img_num = buttons.length / 2;
|
||||
if (img_num <= del_num){
|
||||
setTimeout(function(tabname){
|
||||
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
|
||||
}, 30, tabname);
|
||||
} else {
|
||||
var next_img
|
||||
for (var i = 0; i < del_num; i++){
|
||||
if (image_index + i < image_index + img_num){
|
||||
buttons[image_index + i].style.display = 'none';
|
||||
buttons[image_index + img_num + 1].style.display = 'none';
|
||||
next_img = image_index + i + 1
|
||||
}
|
||||
}
|
||||
var bnt;
|
||||
if (next_img >= img_num){
|
||||
btn = buttons[image_index - del_num];
|
||||
} else {
|
||||
btn = buttons[next_img];
|
||||
}
|
||||
setTimeout(function(btn){btn.click()}, 30, btn);
|
||||
}
|
||||
images_history_disabled_del();
|
||||
return [del_num, tabname, img_path, img_file_name, page_index, filenames, image_index];
|
||||
}
|
||||
|
||||
function images_history_turnpage(img_path, page_index, image_index, tabname){
|
||||
var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item");
|
||||
buttons.forEach(function(elem) {
|
||||
elem.style.display = 'block';
|
||||
})
|
||||
return [img_path, page_index, image_index, tabname];
|
||||
}
|
||||
|
||||
function images_history_enable_del_buttons(){
|
||||
gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
|
||||
btn.removeAttribute('disabled');
|
||||
})
|
||||
}
|
||||
|
||||
function images_history_init(){
|
||||
var load_txt2img_button = gradioApp().getElementById('txt2img_images_history_renew_page')
|
||||
if (load_txt2img_button){
|
||||
for (var i in images_history_tab_list ){
|
||||
tab = images_history_tab_list[i];
|
||||
gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor");
|
||||
gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index");
|
||||
gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button");
|
||||
gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery");
|
||||
|
||||
}
|
||||
var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div");
|
||||
tabs_box.setAttribute("id", "images_history_tab");
|
||||
var tab_btns = tabs_box.querySelectorAll("button");
|
||||
for (var i in images_history_tab_list){
|
||||
var tabname = images_history_tab_list[i]
|
||||
tab_btns[i].setAttribute("tabname", tabname);
|
||||
|
||||
// this refreshes history upon tab switch
|
||||
// until the history is known to work well, which is not the case now, we do not do this at startup
|
||||
//tab_btns[i].addEventListener('click', images_history_click_tab);
|
||||
}
|
||||
tabs_box.classList.add(images_history_tab_list[0]);
|
||||
|
||||
// same as above, at page load
|
||||
//load_txt2img_button.click();
|
||||
} else {
|
||||
setTimeout(images_history_init, 500);
|
||||
}
|
||||
}
|
||||
|
||||
var images_history_tab_list = ["txt2img", "img2img", "extras"];
|
||||
setTimeout(images_history_init, 500);
|
||||
document.addEventListener("DOMContentLoaded", function() {
|
||||
var mutationObserver = new MutationObserver(function(m){
|
||||
for (var i in images_history_tab_list ){
|
||||
let tabname = images_history_tab_list[i]
|
||||
var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item');
|
||||
buttons.forEach(function(bnt){
|
||||
bnt.addEventListener('click', images_history_click_image, true);
|
||||
});
|
||||
|
||||
// same as load_txt2img_button.click() above
|
||||
/*
|
||||
var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
|
||||
if (cls_btn){
|
||||
cls_btn.addEventListener('click', function(){
|
||||
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
|
||||
}, false);
|
||||
}*/
|
||||
|
||||
}
|
||||
});
|
||||
mutationObserver.observe( gradioApp(), { childList:true, subtree:true });
|
||||
|
||||
});
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,241 +0,0 @@
|
||||
import copy
|
||||
import itertools
|
||||
import os
|
||||
from pathlib import Path
|
||||
import html
|
||||
import gc
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import optim
|
||||
|
||||
from modules import shared
|
||||
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
|
||||
from tqdm.auto import tqdm, trange
|
||||
from modules.shared import opts, device
|
||||
|
||||
|
||||
def get_all_images_in_folder(folder):
|
||||
return [os.path.join(folder, f) for f in os.listdir(folder) if
|
||||
os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]
|
||||
|
||||
|
||||
def check_is_valid_image_file(filename):
|
||||
return filename.lower().endswith(('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp"))
|
||||
|
||||
|
||||
def batched(dataset, total, n=1):
|
||||
for ndx in range(0, total, n):
|
||||
yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
|
||||
|
||||
|
||||
def iter_to_batched(iterable, n=1):
|
||||
it = iter(iterable)
|
||||
while True:
|
||||
chunk = tuple(itertools.islice(it, n))
|
||||
if not chunk:
|
||||
return
|
||||
yield chunk
|
||||
|
||||
|
||||
def create_ui():
|
||||
import modules.ui
|
||||
|
||||
with gr.Group():
|
||||
with gr.Accordion("Open for Clip Aesthetic!", open=False):
|
||||
with gr.Row():
|
||||
aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight",
|
||||
value=0.9)
|
||||
aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
|
||||
|
||||
with gr.Row():
|
||||
aesthetic_lr = gr.Textbox(label='Aesthetic learning rate',
|
||||
placeholder="Aesthetic learning rate", value="0.0001")
|
||||
aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
|
||||
aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()),
|
||||
label="Aesthetic imgs embedding",
|
||||
value="None")
|
||||
|
||||
modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings")
|
||||
|
||||
with gr.Row():
|
||||
aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
|
||||
placeholder="This text is used to rotate the feature space of the imgs embs",
|
||||
value="")
|
||||
aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01,
|
||||
value=0.1)
|
||||
aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
|
||||
|
||||
return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative
|
||||
|
||||
|
||||
aesthetic_clip_model = None
|
||||
|
||||
|
||||
def aesthetic_clip():
|
||||
global aesthetic_clip_model
|
||||
|
||||
if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path:
|
||||
aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path)
|
||||
aesthetic_clip_model.cpu()
|
||||
|
||||
return aesthetic_clip_model
|
||||
|
||||
|
||||
def generate_imgs_embd(name, folder, batch_size):
|
||||
model = aesthetic_clip().to(device)
|
||||
processor = CLIPProcessor.from_pretrained(model.name_or_path)
|
||||
|
||||
with torch.no_grad():
|
||||
embs = []
|
||||
for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
|
||||
desc=f"Generating embeddings for {name}"):
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
|
||||
outputs = model.get_image_features(**inputs).cpu()
|
||||
embs.append(torch.clone(outputs))
|
||||
inputs.to("cpu")
|
||||
del inputs, outputs
|
||||
|
||||
embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
|
||||
|
||||
# The generated embedding will be located here
|
||||
path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
|
||||
torch.save(embs, path)
|
||||
|
||||
model.cpu()
|
||||
del processor
|
||||
del embs
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
res = f"""
|
||||
Done generating embedding for {name}!
|
||||
Aesthetic embedding saved to {html.escape(path)}
|
||||
"""
|
||||
shared.update_aesthetic_embeddings()
|
||||
return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
|
||||
value="None"), \
|
||||
gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()),
|
||||
label="Imgs embedding",
|
||||
value="None"), res, ""
|
||||
|
||||
|
||||
def slerp(low, high, val):
|
||||
low_norm = low / torch.norm(low, dim=1, keepdim=True)
|
||||
high_norm = high / torch.norm(high, dim=1, keepdim=True)
|
||||
omega = torch.acos((low_norm * high_norm).sum(1))
|
||||
so = torch.sin(omega)
|
||||
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
|
||||
return res
|
||||
|
||||
|
||||
class AestheticCLIP:
|
||||
def __init__(self):
|
||||
self.skip = False
|
||||
self.aesthetic_steps = 0
|
||||
self.aesthetic_weight = 0
|
||||
self.aesthetic_lr = 0
|
||||
self.slerp = False
|
||||
self.aesthetic_text_negative = ""
|
||||
self.aesthetic_slerp_angle = 0
|
||||
self.aesthetic_imgs_text = ""
|
||||
|
||||
self.image_embs_name = None
|
||||
self.image_embs = None
|
||||
self.load_image_embs(None)
|
||||
|
||||
def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
|
||||
aesthetic_slerp=True, aesthetic_imgs_text="",
|
||||
aesthetic_slerp_angle=0.15,
|
||||
aesthetic_text_negative=False):
|
||||
self.aesthetic_imgs_text = aesthetic_imgs_text
|
||||
self.aesthetic_slerp_angle = aesthetic_slerp_angle
|
||||
self.aesthetic_text_negative = aesthetic_text_negative
|
||||
self.slerp = aesthetic_slerp
|
||||
self.aesthetic_lr = aesthetic_lr
|
||||
self.aesthetic_weight = aesthetic_weight
|
||||
self.aesthetic_steps = aesthetic_steps
|
||||
self.load_image_embs(image_embs_name)
|
||||
|
||||
if self.image_embs_name is not None:
|
||||
p.extra_generation_params.update({
|
||||
"Aesthetic LR": aesthetic_lr,
|
||||
"Aesthetic weight": aesthetic_weight,
|
||||
"Aesthetic steps": aesthetic_steps,
|
||||
"Aesthetic embedding": self.image_embs_name,
|
||||
"Aesthetic slerp": aesthetic_slerp,
|
||||
"Aesthetic text": aesthetic_imgs_text,
|
||||
"Aesthetic text negative": aesthetic_text_negative,
|
||||
"Aesthetic slerp angle": aesthetic_slerp_angle,
|
||||
})
|
||||
|
||||
def set_skip(self, skip):
|
||||
self.skip = skip
|
||||
|
||||
def load_image_embs(self, image_embs_name):
|
||||
if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
|
||||
image_embs_name = None
|
||||
self.image_embs_name = None
|
||||
if image_embs_name is not None and self.image_embs_name != image_embs_name:
|
||||
self.image_embs_name = image_embs_name
|
||||
self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
|
||||
self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
|
||||
self.image_embs.requires_grad_(False)
|
||||
|
||||
def __call__(self, z, remade_batch_tokens):
|
||||
if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None:
|
||||
tokenizer = shared.sd_model.cond_stage_model.tokenizer
|
||||
if not opts.use_old_emphasis_implementation:
|
||||
remade_batch_tokens = [
|
||||
[tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in
|
||||
remade_batch_tokens]
|
||||
|
||||
tokens = torch.asarray(remade_batch_tokens).to(device)
|
||||
|
||||
model = copy.deepcopy(aesthetic_clip()).to(device)
|
||||
model.requires_grad_(True)
|
||||
if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
|
||||
text_embs_2 = model.get_text_features(
|
||||
**tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
|
||||
if self.aesthetic_text_negative:
|
||||
text_embs_2 = self.image_embs - text_embs_2
|
||||
text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
|
||||
img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
|
||||
else:
|
||||
img_embs = self.image_embs
|
||||
|
||||
with torch.enable_grad():
|
||||
|
||||
# We optimize the model to maximize the similarity
|
||||
optimizer = optim.Adam(
|
||||
model.text_model.parameters(), lr=self.aesthetic_lr
|
||||
)
|
||||
|
||||
for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
|
||||
text_embs = model.get_text_features(input_ids=tokens)
|
||||
text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
|
||||
sim = text_embs @ img_embs.T
|
||||
loss = -sim
|
||||
optimizer.zero_grad()
|
||||
loss.mean().backward()
|
||||
optimizer.step()
|
||||
|
||||
zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||
if opts.CLIP_stop_at_last_layers > 1:
|
||||
zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||
zn = model.text_model.final_layer_norm(zn)
|
||||
else:
|
||||
zn = zn.last_hidden_state
|
||||
model.cpu()
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1)
|
||||
if self.slerp:
|
||||
z = slerp(z, zn, self.aesthetic_weight)
|
||||
else:
|
||||
z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
|
||||
|
||||
return z
|
||||
@ -0,0 +1,167 @@
|
||||
import inspect
|
||||
from click import prompt
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from typing import Any, Optional
|
||||
from typing_extensions import Literal
|
||||
from inflection import underscore
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||
from modules.shared import sd_upscalers
|
||||
|
||||
API_NOT_ALLOWED = [
|
||||
"self",
|
||||
"kwargs",
|
||||
"sd_model",
|
||||
"outpath_samples",
|
||||
"outpath_grids",
|
||||
"sampler_index",
|
||||
"do_not_save_samples",
|
||||
"do_not_save_grid",
|
||||
"extra_generation_params",
|
||||
"overlay_images",
|
||||
"do_not_reload_embeddings",
|
||||
"seed_enable_extras",
|
||||
"prompt_for_display",
|
||||
"sampler_noise_scheduler_override",
|
||||
"ddim_discretize"
|
||||
]
|
||||
|
||||
class ModelDef(BaseModel):
|
||||
"""Assistance Class for Pydantic Dynamic Model Generation"""
|
||||
|
||||
field: str
|
||||
field_alias: str
|
||||
field_type: Any
|
||||
field_value: Any
|
||||
field_exclude: bool = False
|
||||
|
||||
|
||||
class PydanticModelGenerator:
|
||||
"""
|
||||
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
|
||||
source_data is a snapshot of the default values produced by the class
|
||||
params are the names of the actual keys required by __init__
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = None,
|
||||
class_instance = None,
|
||||
additional_fields = None,
|
||||
):
|
||||
def field_type_generator(k, v):
|
||||
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
||||
# print(k, v.annotation, v.default)
|
||||
field_type = v.annotation
|
||||
|
||||
return Optional[field_type]
|
||||
|
||||
def merge_class_params(class_):
|
||||
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
|
||||
parameters = {}
|
||||
for classes in all_classes:
|
||||
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||
return parameters
|
||||
|
||||
|
||||
self._model_name = model_name
|
||||
self._class_data = merge_class_params(class_instance)
|
||||
self._model_def = [
|
||||
ModelDef(
|
||||
field=underscore(k),
|
||||
field_alias=k,
|
||||
field_type=field_type_generator(k, v),
|
||||
field_value=v.default
|
||||
)
|
||||
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||
]
|
||||
|
||||
for fields in additional_fields:
|
||||
self._model_def.append(ModelDef(
|
||||
field=underscore(fields["key"]),
|
||||
field_alias=fields["key"],
|
||||
field_type=fields["type"],
|
||||
field_value=fields["default"],
|
||||
field_exclude=fields["exclude"] if "exclude" in fields else False))
|
||||
|
||||
def generate_model(self):
|
||||
"""
|
||||
Creates a pydantic BaseModel
|
||||
from the json and overrides provided at initialization
|
||||
"""
|
||||
fields = {
|
||||
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
|
||||
}
|
||||
DynamicModel = create_model(self._model_name, **fields)
|
||||
DynamicModel.__config__.allow_population_by_field_name = True
|
||||
DynamicModel.__config__.allow_mutation = True
|
||||
return DynamicModel
|
||||
|
||||
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingTxt2Img",
|
||||
StableDiffusionProcessingTxt2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}]
|
||||
).generate_model()
|
||||
|
||||
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingImg2Img",
|
||||
StableDiffusionProcessingImg2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
|
||||
).generate_model()
|
||||
|
||||
class TextToImageResponse(BaseModel):
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
class ImageToImageResponse(BaseModel):
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
class ExtrasBaseRequest(BaseModel):
|
||||
resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
|
||||
show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
|
||||
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
|
||||
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
|
||||
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
|
||||
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
|
||||
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
|
||||
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
|
||||
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?")
|
||||
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
|
||||
|
||||
class ExtraBaseResponse(BaseModel):
|
||||
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
|
||||
|
||||
class ExtrasSingleImageRequest(ExtrasBaseRequest):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
|
||||
class ExtrasSingleImageResponse(ExtraBaseResponse):
|
||||
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
|
||||
class FileData(BaseModel):
|
||||
data: str = Field(title="File data", description="Base64 representation of the file")
|
||||
name: str = Field(title="File name")
|
||||
|
||||
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
|
||||
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
||||
|
||||
class ExtrasBatchImagesResponse(ExtraBaseResponse):
|
||||
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
|
||||
|
||||
class PNGInfoRequest(BaseModel):
|
||||
image: str = Field(title="Image", description="The base64 encoded PNG image")
|
||||
|
||||
class PNGInfoResponse(BaseModel):
|
||||
info: str = Field(title="Image info", description="A string with all the info the image had")
|
||||
|
||||
class ProgressRequest(BaseModel):
|
||||
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
|
||||
|
||||
class ProgressResponse(BaseModel):
|
||||
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
|
||||
eta_relative: float = Field(title="ETA in secs")
|
||||
state: dict = Field(title="State", description="The current state snapshot")
|
||||
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||
@ -1,99 +0,0 @@
|
||||
from inflection import underscore
|
||||
from typing import Any, Dict, Optional
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img
|
||||
import inspect
|
||||
|
||||
|
||||
API_NOT_ALLOWED = [
|
||||
"self",
|
||||
"kwargs",
|
||||
"sd_model",
|
||||
"outpath_samples",
|
||||
"outpath_grids",
|
||||
"sampler_index",
|
||||
"do_not_save_samples",
|
||||
"do_not_save_grid",
|
||||
"extra_generation_params",
|
||||
"overlay_images",
|
||||
"do_not_reload_embeddings",
|
||||
"seed_enable_extras",
|
||||
"prompt_for_display",
|
||||
"sampler_noise_scheduler_override",
|
||||
"ddim_discretize"
|
||||
]
|
||||
|
||||
class ModelDef(BaseModel):
|
||||
"""Assistance Class for Pydantic Dynamic Model Generation"""
|
||||
|
||||
field: str
|
||||
field_alias: str
|
||||
field_type: Any
|
||||
field_value: Any
|
||||
|
||||
|
||||
class PydanticModelGenerator:
|
||||
"""
|
||||
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
|
||||
source_data is a snapshot of the default values produced by the class
|
||||
params are the names of the actual keys required by __init__
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = None,
|
||||
class_instance = None,
|
||||
additional_fields = None,
|
||||
):
|
||||
def field_type_generator(k, v):
|
||||
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
||||
# print(k, v.annotation, v.default)
|
||||
field_type = v.annotation
|
||||
|
||||
return Optional[field_type]
|
||||
|
||||
def merge_class_params(class_):
|
||||
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
|
||||
parameters = {}
|
||||
for classes in all_classes:
|
||||
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||
return parameters
|
||||
|
||||
|
||||
self._model_name = model_name
|
||||
self._class_data = merge_class_params(class_instance)
|
||||
self._model_def = [
|
||||
ModelDef(
|
||||
field=underscore(k),
|
||||
field_alias=k,
|
||||
field_type=field_type_generator(k, v),
|
||||
field_value=v.default
|
||||
)
|
||||
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||
]
|
||||
|
||||
for fields in additional_fields:
|
||||
self._model_def.append(ModelDef(
|
||||
field=underscore(fields["key"]),
|
||||
field_alias=fields["key"],
|
||||
field_type=fields["type"],
|
||||
field_value=fields["default"]))
|
||||
|
||||
def generate_model(self):
|
||||
"""
|
||||
Creates a pydantic BaseModel
|
||||
from the json and overrides provided at initialization
|
||||
"""
|
||||
fields = {
|
||||
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
|
||||
}
|
||||
DynamicModel = create_model(self._model_name, **fields)
|
||||
DynamicModel.__config__.allow_population_by_field_name = True
|
||||
DynamicModel.__config__.allow_mutation = True
|
||||
return DynamicModel
|
||||
|
||||
StableDiffusionProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingTxt2Img",
|
||||
StableDiffusionProcessingTxt2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}]
|
||||
).generate_model()
|
||||
@ -1,76 +0,0 @@
|
||||
import os.path
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import PIL.Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.upscaler
|
||||
from modules import devices, modelloader
|
||||
from modules.bsrgan_model_arch import RRDBNet
|
||||
|
||||
|
||||
class UpscalerBSRGAN(modules.upscaler.Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "BSRGAN"
|
||||
self.model_name = "BSRGAN 4x"
|
||||
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
|
||||
self.user_path = dirname
|
||||
super().__init__()
|
||||
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
||||
scalers = []
|
||||
if len(model_paths) == 0:
|
||||
scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
for file in model_paths:
|
||||
if "http" in file:
|
||||
name = self.model_name
|
||||
else:
|
||||
name = modelloader.friendly_name(file)
|
||||
try:
|
||||
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
except Exception:
|
||||
print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
self.scalers = scalers
|
||||
|
||||
def do_upscale(self, img: PIL.Image, selected_file):
|
||||
torch.cuda.empty_cache()
|
||||
model = self.load_model(selected_file)
|
||||
if model is None:
|
||||
return img
|
||||
model.to(devices.device_bsrgan)
|
||||
torch.cuda.empty_cache()
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(devices.device_bsrgan)
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output = 255. * np.moveaxis(output, 0, 2)
|
||||
output = output.astype(np.uint8)
|
||||
output = output[:, :, ::-1]
|
||||
torch.cuda.empty_cache()
|
||||
return PIL.Image.fromarray(output, 'RGB')
|
||||
|
||||
def load_model(self, path: str):
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
||||
progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
|
||||
return None
|
||||
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
|
||||
model.load_state_dict(torch.load(filename), strict=True)
|
||||
model.eval()
|
||||
for k, v in model.named_parameters():
|
||||
v.requires_grad = False
|
||||
return model
|
||||
|
||||
@ -1,102 +0,0 @@
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
|
||||
|
||||
def initialize_weights(net_l, scale=1):
|
||||
if not isinstance(net_l, list):
|
||||
net_l = [net_l]
|
||||
for net in net_l:
|
||||
for m in net.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale # for residual block
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias.data, 0.0)
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
# initialization
|
||||
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
'''Residual in Residual Dense Block'''
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
self.sf = sf
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
if self.sf==4:
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
|
||||
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
if self.sf==4:
|
||||
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
return out
|
||||
@ -1,80 +1,463 @@
|
||||
# this file is taken from https://github.com/xinntao/ESRGAN
|
||||
# this file is adapted from https://github.com/victorca25/iNNfer
|
||||
|
||||
import math
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
####################
|
||||
# RRDBNet Generator
|
||||
####################
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
|
||||
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
|
||||
finalact=None, gaussian_noise=False, plus=False):
|
||||
super(RRDBNet, self).__init__()
|
||||
n_upscale = int(math.log(upscale, 2))
|
||||
if upscale == 3:
|
||||
n_upscale = 1
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
self.resrgan_scale = 0
|
||||
if in_nc % 16 == 0:
|
||||
self.resrgan_scale = 1
|
||||
elif in_nc != 4 and in_nc % 4 == 0:
|
||||
self.resrgan_scale = 2
|
||||
|
||||
# initialization
|
||||
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
||||
rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
|
||||
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
|
||||
LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
if upsample_mode == 'upconv':
|
||||
upsample_block = upconv_block
|
||||
elif upsample_mode == 'pixelshuffle':
|
||||
upsample_block = pixelshuffle_block
|
||||
else:
|
||||
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
|
||||
if upscale == 3:
|
||||
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
|
||||
else:
|
||||
upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
|
||||
HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
|
||||
HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
||||
|
||||
outact = act(finalact) if finalact else None
|
||||
|
||||
self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
|
||||
*upsampler, HR_conv0, HR_conv1, outact)
|
||||
|
||||
def forward(self, x, outm=None):
|
||||
if self.resrgan_scale == 1:
|
||||
feat = pixel_unshuffle(x, scale=4)
|
||||
elif self.resrgan_scale == 2:
|
||||
feat = pixel_unshuffle(x, scale=2)
|
||||
else:
|
||||
feat = x
|
||||
|
||||
return self.model(feat)
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
'''Residual in Residual Dense Block'''
|
||||
"""
|
||||
Residual in Residual Dense Block
|
||||
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
|
||||
"""
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False, gaussian_noise=False, plus=False):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
# This is for backwards compatibility with existing models
|
||||
if nr == 3:
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
else:
|
||||
RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
|
||||
self.RDBs = nn.Sequential(*RDB_list)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
if hasattr(self, 'RDB1'):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
else:
|
||||
out = self.RDBs(x)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
"""
|
||||
Residual Dense Block
|
||||
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
|
||||
Modified options that can be used:
|
||||
- "Partial Convolution based Padding" arXiv:1811.11718
|
||||
- "Spectral normalization" arXiv:1802.05957
|
||||
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
||||
{Rakotonirina} and A. {Rasoanaivo}
|
||||
"""
|
||||
|
||||
def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False, gaussian_noise=False, plus=False):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
self.noise = GaussianNoise() if gaussian_noise else None
|
||||
self.conv1x1 = conv1x1(nf, gc) if plus else None
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
if mode == 'CNA':
|
||||
last_act = None
|
||||
else:
|
||||
last_act = act_type
|
||||
self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
x1 = self.conv1(x)
|
||||
x2 = self.conv2(torch.cat((x, x1), 1))
|
||||
if self.conv1x1:
|
||||
x2 = x2 + self.conv1x1(x)
|
||||
x3 = self.conv3(torch.cat((x, x1, x2), 1))
|
||||
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
|
||||
if self.conv1x1:
|
||||
x4 = x4 + x2
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
if self.noise:
|
||||
return self.noise(x5.mul(0.2) + x)
|
||||
else:
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
####################
|
||||
# ESRGANplus
|
||||
####################
|
||||
|
||||
class GaussianNoise(nn.Module):
|
||||
def __init__(self, sigma=0.1, is_relative_detach=False):
|
||||
super().__init__()
|
||||
self.sigma = sigma
|
||||
self.is_relative_detach = is_relative_detach
|
||||
self.noise = torch.tensor(0, dtype=torch.float)
|
||||
|
||||
def forward(self, x):
|
||||
if self.training and self.sigma != 0:
|
||||
self.noise = self.noise.to(x.device)
|
||||
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
||||
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
||||
x = x + sampled_noise
|
||||
return x
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
####################
|
||||
# SRVGGNetCompact
|
||||
####################
|
||||
|
||||
class SRVGGNetCompact(nn.Module):
|
||||
"""A compact VGG-style network structure for super-resolution.
|
||||
This class is copied from https://github.com/xinntao/Real-ESRGAN
|
||||
"""
|
||||
|
||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
|
||||
super(SRVGGNetCompact, self).__init__()
|
||||
self.num_in_ch = num_in_ch
|
||||
self.num_out_ch = num_out_ch
|
||||
self.num_feat = num_feat
|
||||
self.num_conv = num_conv
|
||||
self.upscale = upscale
|
||||
self.act_type = act_type
|
||||
|
||||
self.body = nn.ModuleList()
|
||||
# the first conv
|
||||
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
|
||||
# the first activation
|
||||
if act_type == 'relu':
|
||||
activation = nn.ReLU(inplace=True)
|
||||
elif act_type == 'prelu':
|
||||
activation = nn.PReLU(num_parameters=num_feat)
|
||||
elif act_type == 'leakyrelu':
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.body.append(activation)
|
||||
|
||||
# the body structure
|
||||
for _ in range(num_conv):
|
||||
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
|
||||
# activation
|
||||
if act_type == 'relu':
|
||||
activation = nn.ReLU(inplace=True)
|
||||
elif act_type == 'prelu':
|
||||
activation = nn.PReLU(num_parameters=num_feat)
|
||||
elif act_type == 'leakyrelu':
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.body.append(activation)
|
||||
|
||||
# the last conv
|
||||
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
|
||||
# upsample
|
||||
self.upsampler = nn.PixelShuffle(upscale)
|
||||
|
||||
def forward(self, x):
|
||||
out = x
|
||||
for i in range(0, len(self.body)):
|
||||
out = self.body[i](out)
|
||||
|
||||
out = self.upsampler(out)
|
||||
# add the nearest upsampled image, so that the network learns the residual
|
||||
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
|
||||
out += base
|
||||
return out
|
||||
|
||||
|
||||
####################
|
||||
# Upsampler
|
||||
####################
|
||||
|
||||
class Upsample(nn.Module):
|
||||
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
|
||||
The input data is assumed to be of the form
|
||||
`minibatch x channels x [optional depth] x [optional height] x width`.
|
||||
"""
|
||||
|
||||
def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
super(Upsample, self).__init__()
|
||||
if isinstance(scale_factor, tuple):
|
||||
self.scale_factor = tuple(float(factor) for factor in scale_factor)
|
||||
else:
|
||||
self.scale_factor = float(scale_factor) if scale_factor else None
|
||||
self.mode = mode
|
||||
self.size = size
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
|
||||
|
||||
def extra_repr(self):
|
||||
if self.scale_factor is not None:
|
||||
info = 'scale_factor=' + str(self.scale_factor)
|
||||
else:
|
||||
info = 'size=' + str(self.size)
|
||||
info += ', mode=' + self.mode
|
||||
return info
|
||||
|
||||
|
||||
def pixel_unshuffle(x, scale):
|
||||
""" Pixel unshuffle.
|
||||
Args:
|
||||
x (Tensor): Input feature with shape (b, c, hh, hw).
|
||||
scale (int): Downsample ratio.
|
||||
Returns:
|
||||
Tensor: the pixel unshuffled feature.
|
||||
"""
|
||||
b, c, hh, hw = x.size()
|
||||
out_channel = c * (scale**2)
|
||||
assert hh % scale == 0 and hw % scale == 0
|
||||
h = hh // scale
|
||||
w = hw // scale
|
||||
x_view = x.view(b, c, h, scale, w, scale)
|
||||
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
|
||||
|
||||
|
||||
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
|
||||
"""
|
||||
Pixel shuffle layer
|
||||
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
|
||||
Neural Network, CVPR17)
|
||||
"""
|
||||
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
|
||||
pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
|
||||
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
||||
|
||||
n = norm(norm_type, out_nc) if norm_type else None
|
||||
a = act(act_type) if act_type else None
|
||||
return sequential(conv, pixel_shuffle, n, a)
|
||||
|
||||
|
||||
def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
|
||||
""" Upconv layer """
|
||||
upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
|
||||
upsample = Upsample(scale_factor=upscale_factor, mode=mode)
|
||||
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
|
||||
pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
|
||||
return sequential(upsample, conv)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
####################
|
||||
# Basic blocks
|
||||
####################
|
||||
|
||||
|
||||
def make_layer(basic_block, num_basic_block, **kwarg):
|
||||
"""Make layers by stacking the same blocks.
|
||||
Args:
|
||||
basic_block (nn.module): nn.module class for basic block. (block)
|
||||
num_basic_block (int): number of blocks. (n_layers)
|
||||
Returns:
|
||||
nn.Sequential: Stacked blocks in nn.Sequential.
|
||||
"""
|
||||
layers = []
|
||||
for _ in range(num_basic_block):
|
||||
layers.append(basic_block(**kwarg))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
|
||||
""" activation helper """
|
||||
act_type = act_type.lower()
|
||||
if act_type == 'relu':
|
||||
layer = nn.ReLU(inplace)
|
||||
elif act_type in ('leakyrelu', 'lrelu'):
|
||||
layer = nn.LeakyReLU(neg_slope, inplace)
|
||||
elif act_type == 'prelu':
|
||||
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
|
||||
elif act_type == 'tanh': # [-1, 1] range output
|
||||
layer = nn.Tanh()
|
||||
elif act_type == 'sigmoid': # [0, 1] range output
|
||||
layer = nn.Sigmoid()
|
||||
else:
|
||||
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
|
||||
return layer
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self, *kwargs):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x, *kwargs):
|
||||
return x
|
||||
|
||||
|
||||
def norm(norm_type, nc):
|
||||
""" Return a normalization layer """
|
||||
norm_type = norm_type.lower()
|
||||
if norm_type == 'batch':
|
||||
layer = nn.BatchNorm2d(nc, affine=True)
|
||||
elif norm_type == 'instance':
|
||||
layer = nn.InstanceNorm2d(nc, affine=False)
|
||||
elif norm_type == 'none':
|
||||
def norm_layer(x): return Identity()
|
||||
else:
|
||||
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
|
||||
return layer
|
||||
|
||||
|
||||
def pad(pad_type, padding):
|
||||
""" padding layer helper """
|
||||
pad_type = pad_type.lower()
|
||||
if padding == 0:
|
||||
return None
|
||||
if pad_type == 'reflect':
|
||||
layer = nn.ReflectionPad2d(padding)
|
||||
elif pad_type == 'replicate':
|
||||
layer = nn.ReplicationPad2d(padding)
|
||||
elif pad_type == 'zero':
|
||||
layer = nn.ZeroPad2d(padding)
|
||||
else:
|
||||
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
|
||||
return layer
|
||||
|
||||
|
||||
def get_valid_padding(kernel_size, dilation):
|
||||
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
|
||||
padding = (kernel_size - 1) // 2
|
||||
return padding
|
||||
|
||||
|
||||
class ShortcutBlock(nn.Module):
|
||||
""" Elementwise sum the output of a submodule to its input """
|
||||
def __init__(self, submodule):
|
||||
super(ShortcutBlock, self).__init__()
|
||||
self.sub = submodule
|
||||
|
||||
def forward(self, x):
|
||||
output = x + self.sub(x)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
|
||||
|
||||
|
||||
def sequential(*args):
|
||||
""" Flatten Sequential. It unwraps nn.Sequential. """
|
||||
if len(args) == 1:
|
||||
if isinstance(args[0], OrderedDict):
|
||||
raise NotImplementedError('sequential does not support OrderedDict input.')
|
||||
return args[0] # No sequential is needed.
|
||||
modules = []
|
||||
for module in args:
|
||||
if isinstance(module, nn.Sequential):
|
||||
for submodule in module.children():
|
||||
modules.append(submodule)
|
||||
elif isinstance(module, nn.Module):
|
||||
modules.append(module)
|
||||
return nn.Sequential(*modules)
|
||||
|
||||
|
||||
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False):
|
||||
""" Conv layer with padding, normalization, activation """
|
||||
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
|
||||
padding = get_valid_padding(kernel_size, dilation)
|
||||
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
||||
padding = padding if pad_type == 'zero' else 0
|
||||
|
||||
if convtype=='PartialConv2D':
|
||||
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
elif convtype=='DeformConv2D':
|
||||
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
elif convtype=='Conv3D':
|
||||
c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
else:
|
||||
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
|
||||
if spectral_norm:
|
||||
c = nn.utils.spectral_norm(c)
|
||||
|
||||
a = act(act_type) if act_type else None
|
||||
if 'CNA' in mode:
|
||||
n = norm(norm_type, out_nc) if norm_type else None
|
||||
return sequential(p, c, n, a)
|
||||
elif mode == 'NAC':
|
||||
if norm_type is None and act_type is not None:
|
||||
a = act(act_type, inplace=False)
|
||||
n = norm(norm_type, in_nc) if norm_type else None
|
||||
return sequential(n, a, p, c)
|
||||
|
||||
@ -0,0 +1,83 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import git
|
||||
|
||||
from modules import paths, shared
|
||||
|
||||
|
||||
extensions = []
|
||||
extensions_dir = os.path.join(paths.script_path, "extensions")
|
||||
|
||||
|
||||
def active():
|
||||
return [x for x in extensions if x.enabled]
|
||||
|
||||
|
||||
class Extension:
|
||||
def __init__(self, name, path, enabled=True):
|
||||
self.name = name
|
||||
self.path = path
|
||||
self.enabled = enabled
|
||||
self.status = ''
|
||||
self.can_update = False
|
||||
|
||||
repo = None
|
||||
try:
|
||||
if os.path.exists(os.path.join(path, ".git")):
|
||||
repo = git.Repo(path)
|
||||
except Exception:
|
||||
print(f"Error reading github repository info from {path}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
if repo is None or repo.bare:
|
||||
self.remote = None
|
||||
else:
|
||||
self.remote = next(repo.remote().urls, None)
|
||||
self.status = 'unknown'
|
||||
|
||||
def list_files(self, subdir, extension):
|
||||
from modules import scripts
|
||||
|
||||
dirpath = os.path.join(self.path, subdir)
|
||||
if not os.path.isdir(dirpath):
|
||||
return []
|
||||
|
||||
res = []
|
||||
for filename in sorted(os.listdir(dirpath)):
|
||||
res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
|
||||
|
||||
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||
|
||||
return res
|
||||
|
||||
def check_updates(self):
|
||||
repo = git.Repo(self.path)
|
||||
for fetch in repo.remote().fetch("--dry-run"):
|
||||
if fetch.flags != fetch.HEAD_UPTODATE:
|
||||
self.can_update = True
|
||||
self.status = "behind"
|
||||
return
|
||||
|
||||
self.can_update = False
|
||||
self.status = "latest"
|
||||
|
||||
def pull(self):
|
||||
repo = git.Repo(self.path)
|
||||
repo.remotes.origin.pull()
|
||||
|
||||
|
||||
def list_extensions():
|
||||
extensions.clear()
|
||||
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
for dirname in sorted(os.listdir(extensions_dir)):
|
||||
path = os.path.join(extensions_dir, dirname)
|
||||
if not os.path.isdir(path):
|
||||
continue
|
||||
|
||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
|
||||
extensions.append(extension)
|
||||
@ -1,183 +0,0 @@
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
def traverse_all_files(output_dir, image_list, curr_dir=None):
|
||||
curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
|
||||
try:
|
||||
f_list = os.listdir(curr_path)
|
||||
except:
|
||||
if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
|
||||
image_list.append(curr_dir)
|
||||
return image_list
|
||||
for file in f_list:
|
||||
file = file if curr_dir is None else os.path.join(curr_dir, file)
|
||||
file_path = os.path.join(curr_path, file)
|
||||
if file[-4:] == ".txt":
|
||||
pass
|
||||
elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
|
||||
image_list.append(file)
|
||||
else:
|
||||
image_list = traverse_all_files(output_dir, image_list, file)
|
||||
return image_list
|
||||
|
||||
|
||||
def get_recent_images(dir_name, page_index, step, image_index, tabname):
|
||||
page_index = int(page_index)
|
||||
image_list = []
|
||||
if not os.path.exists(dir_name):
|
||||
pass
|
||||
elif os.path.isdir(dir_name):
|
||||
image_list = traverse_all_files(dir_name, image_list)
|
||||
image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
|
||||
else:
|
||||
print(f'ERROR: "{dir_name}" is not a directory. Check the path in the settings.', file=sys.stderr)
|
||||
num = 48 if tabname != "extras" else 12
|
||||
max_page_index = len(image_list) // num + 1
|
||||
page_index = max_page_index if page_index == -1 else page_index + step
|
||||
page_index = 1 if page_index < 1 else page_index
|
||||
page_index = max_page_index if page_index > max_page_index else page_index
|
||||
idx_frm = (page_index - 1) * num
|
||||
image_list = image_list[idx_frm:idx_frm + num]
|
||||
image_index = int(image_index)
|
||||
if image_index < 0 or image_index > len(image_list) - 1:
|
||||
current_file = None
|
||||
hidden = None
|
||||
else:
|
||||
current_file = image_list[int(image_index)]
|
||||
hidden = os.path.join(dir_name, current_file)
|
||||
return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
|
||||
|
||||
|
||||
def first_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, 1, 0, image_index, tabname)
|
||||
|
||||
|
||||
def end_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, -1, 0, image_index, tabname)
|
||||
|
||||
|
||||
def prev_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, page_index, -1, image_index, tabname)
|
||||
|
||||
|
||||
def next_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, page_index, 1, image_index, tabname)
|
||||
|
||||
|
||||
def page_index_change(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, page_index, 0, image_index, tabname)
|
||||
|
||||
|
||||
def show_image_info(num, image_path, filenames):
|
||||
# print(f"select image {num}")
|
||||
file = filenames[int(num)]
|
||||
return file, num, os.path.join(image_path, file)
|
||||
|
||||
|
||||
def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
|
||||
if name == "":
|
||||
return filenames, delete_num
|
||||
else:
|
||||
delete_num = int(delete_num)
|
||||
index = list(filenames).index(name)
|
||||
i = 0
|
||||
new_file_list = []
|
||||
for name in filenames:
|
||||
if i >= index and i < index + delete_num:
|
||||
path = os.path.join(dir_name, name)
|
||||
if os.path.exists(path):
|
||||
print(f"Delete file {path}")
|
||||
os.remove(path)
|
||||
txt_file = os.path.splitext(path)[0] + ".txt"
|
||||
if os.path.exists(txt_file):
|
||||
os.remove(txt_file)
|
||||
else:
|
||||
print(f"Not exists file {path}")
|
||||
else:
|
||||
new_file_list.append(name)
|
||||
i += 1
|
||||
return new_file_list, 1
|
||||
|
||||
|
||||
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
|
||||
if opts.outdir_samples != "":
|
||||
dir_name = opts.outdir_samples
|
||||
elif tabname == "txt2img":
|
||||
dir_name = opts.outdir_txt2img_samples
|
||||
elif tabname == "img2img":
|
||||
dir_name = opts.outdir_img2img_samples
|
||||
elif tabname == "extras":
|
||||
dir_name = opts.outdir_extras_samples
|
||||
else:
|
||||
return
|
||||
with gr.Row():
|
||||
renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
|
||||
first_page = gr.Button('First Page')
|
||||
prev_page = gr.Button('Prev Page')
|
||||
page_index = gr.Number(value=1, label="Page Index")
|
||||
next_page = gr.Button('Next Page')
|
||||
end_page = gr.Button('End Page')
|
||||
with gr.Row(elem_id=tabname + "_images_history"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
|
||||
with gr.Row():
|
||||
delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
|
||||
delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
|
||||
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
img_file_info = gr.Textbox(label="Generate Info", interactive=False)
|
||||
img_file_name = gr.Textbox(label="File Name", interactive=False)
|
||||
with gr.Row():
|
||||
# hiden items
|
||||
|
||||
img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
|
||||
tabname_box = gr.Textbox(tabname, visible=False)
|
||||
image_index = gr.Textbox(value=-1, visible=False)
|
||||
set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
|
||||
filenames = gr.State()
|
||||
hidden = gr.Image(type="pil", visible=False)
|
||||
info1 = gr.Textbox(visible=False)
|
||||
info2 = gr.Textbox(visible=False)
|
||||
|
||||
# turn pages
|
||||
gallery_inputs = [img_path, page_index, image_index, tabname_box]
|
||||
gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
|
||||
|
||||
first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
# page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
|
||||
|
||||
# other funcitons
|
||||
set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
|
||||
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
|
||||
delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
|
||||
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
|
||||
|
||||
# pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
|
||||
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
|
||||
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
|
||||
|
||||
|
||||
def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history:
|
||||
with gr.Tabs() as tabs:
|
||||
with gr.Tab("txt2img history"):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
|
||||
show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
|
||||
with gr.Tab("img2img history"):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
|
||||
show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
|
||||
with gr.Tab("extras history"):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
|
||||
show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
|
||||
return images_history
|
||||
@ -0,0 +1,132 @@
|
||||
import sys
|
||||
import traceback
|
||||
from collections import namedtuple
|
||||
import inspect
|
||||
|
||||
|
||||
def report_exception(c, job):
|
||||
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
class ImageSaveParams:
|
||||
def __init__(self, image, p, filename, pnginfo):
|
||||
self.image = image
|
||||
"""the PIL image itself"""
|
||||
|
||||
self.p = p
|
||||
"""p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""
|
||||
|
||||
self.filename = filename
|
||||
"""name of file that the image would be saved to"""
|
||||
|
||||
self.pnginfo = pnginfo
|
||||
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
|
||||
|
||||
|
||||
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
||||
callbacks_model_loaded = []
|
||||
callbacks_ui_tabs = []
|
||||
callbacks_ui_settings = []
|
||||
callbacks_before_image_saved = []
|
||||
callbacks_image_saved = []
|
||||
|
||||
|
||||
def clear_callbacks():
|
||||
callbacks_model_loaded.clear()
|
||||
callbacks_ui_tabs.clear()
|
||||
callbacks_ui_settings.clear()
|
||||
callbacks_before_image_saved.clear()
|
||||
callbacks_image_saved.clear()
|
||||
|
||||
|
||||
def model_loaded_callback(sd_model):
|
||||
for c in callbacks_model_loaded:
|
||||
try:
|
||||
c.callback(sd_model)
|
||||
except Exception:
|
||||
report_exception(c, 'model_loaded_callback')
|
||||
|
||||
|
||||
def ui_tabs_callback():
|
||||
res = []
|
||||
|
||||
for c in callbacks_ui_tabs:
|
||||
try:
|
||||
res += c.callback() or []
|
||||
except Exception:
|
||||
report_exception(c, 'ui_tabs_callback')
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def ui_settings_callback():
|
||||
for c in callbacks_ui_settings:
|
||||
try:
|
||||
c.callback()
|
||||
except Exception:
|
||||
report_exception(c, 'ui_settings_callback')
|
||||
|
||||
|
||||
def before_image_saved_callback(params: ImageSaveParams):
|
||||
for c in callbacks_image_saved:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'before_image_saved_callback')
|
||||
|
||||
|
||||
def image_saved_callback(params: ImageSaveParams):
|
||||
for c in callbacks_image_saved:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'image_saved_callback')
|
||||
|
||||
|
||||
def add_callback(callbacks, fun):
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||
|
||||
callbacks.append(ScriptCallback(filename, fun))
|
||||
|
||||
|
||||
def on_model_loaded(callback):
|
||||
"""register a function to be called when the stable diffusion model is created; the model is
|
||||
passed as an argument"""
|
||||
add_callback(callbacks_model_loaded, callback)
|
||||
|
||||
|
||||
def on_ui_tabs(callback):
|
||||
"""register a function to be called when the UI is creating new tabs.
|
||||
The function must either return a None, which means no new tabs to be added, or a list, where
|
||||
each element is a tuple:
|
||||
(gradio_component, title, elem_id)
|
||||
|
||||
gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
|
||||
title is tab text displayed to user in the UI
|
||||
elem_id is HTML id for the tab
|
||||
"""
|
||||
add_callback(callbacks_ui_tabs, callback)
|
||||
|
||||
|
||||
def on_ui_settings(callback):
|
||||
"""register a function to be called before UI settings are populated; add your settings
|
||||
by using shared.opts.add_option(shared.OptionInfo(...)) """
|
||||
add_callback(callbacks_ui_settings, callback)
|
||||
|
||||
|
||||
def on_before_image_saved(callback):
|
||||
"""register a function to be called before an image is saved to a file.
|
||||
The callback is called with one argument:
|
||||
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
|
||||
"""
|
||||
add_callback(callbacks_before_image_saved, callback)
|
||||
|
||||
|
||||
def on_image_saved(callback):
|
||||
"""register a function to be called after an image is saved to a file.
|
||||
The callback is called with one argument:
|
||||
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
|
||||
"""
|
||||
add_callback(callbacks_image_saved, callback)
|
||||
@ -0,0 +1,341 @@
|
||||
import cv2
|
||||
import requests
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from math import log, sqrt
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
GREEN = "#0F0"
|
||||
BLUE = "#00F"
|
||||
RED = "#F00"
|
||||
|
||||
|
||||
def crop_image(im, settings):
|
||||
""" Intelligently crop an image to the subject matter """
|
||||
|
||||
scale_by = 1
|
||||
if is_landscape(im.width, im.height):
|
||||
scale_by = settings.crop_height / im.height
|
||||
elif is_portrait(im.width, im.height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_square(im.width, im.height):
|
||||
if is_square(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_landscape(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_portrait(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_height / im.height
|
||||
|
||||
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
|
||||
im_debug = im.copy()
|
||||
|
||||
focus = focal_point(im_debug, settings)
|
||||
|
||||
# take the focal point and turn it into crop coordinates that try to center over the focal
|
||||
# point but then get adjusted back into the frame
|
||||
y_half = int(settings.crop_height / 2)
|
||||
x_half = int(settings.crop_width / 2)
|
||||
|
||||
x1 = focus.x - x_half
|
||||
if x1 < 0:
|
||||
x1 = 0
|
||||
elif x1 + settings.crop_width > im.width:
|
||||
x1 = im.width - settings.crop_width
|
||||
|
||||
y1 = focus.y - y_half
|
||||
if y1 < 0:
|
||||
y1 = 0
|
||||
elif y1 + settings.crop_height > im.height:
|
||||
y1 = im.height - settings.crop_height
|
||||
|
||||
x2 = x1 + settings.crop_width
|
||||
y2 = y1 + settings.crop_height
|
||||
|
||||
crop = [x1, y1, x2, y2]
|
||||
|
||||
results = []
|
||||
|
||||
results.append(im.crop(tuple(crop)))
|
||||
|
||||
if settings.annotate_image:
|
||||
d = ImageDraw.Draw(im_debug)
|
||||
rect = list(crop)
|
||||
rect[2] -= 1
|
||||
rect[3] -= 1
|
||||
d.rectangle(rect, outline=GREEN)
|
||||
results.append(im_debug)
|
||||
if settings.destop_view_image:
|
||||
im_debug.show()
|
||||
|
||||
return results
|
||||
|
||||
def focal_point(im, settings):
|
||||
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
|
||||
entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
|
||||
face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
|
||||
|
||||
pois = []
|
||||
|
||||
weight_pref_total = 0
|
||||
if len(corner_points) > 0:
|
||||
weight_pref_total += settings.corner_points_weight
|
||||
if len(entropy_points) > 0:
|
||||
weight_pref_total += settings.entropy_points_weight
|
||||
if len(face_points) > 0:
|
||||
weight_pref_total += settings.face_points_weight
|
||||
|
||||
corner_centroid = None
|
||||
if len(corner_points) > 0:
|
||||
corner_centroid = centroid(corner_points)
|
||||
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
|
||||
pois.append(corner_centroid)
|
||||
|
||||
entropy_centroid = None
|
||||
if len(entropy_points) > 0:
|
||||
entropy_centroid = centroid(entropy_points)
|
||||
entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
|
||||
pois.append(entropy_centroid)
|
||||
|
||||
face_centroid = None
|
||||
if len(face_points) > 0:
|
||||
face_centroid = centroid(face_points)
|
||||
face_centroid.weight = settings.face_points_weight / weight_pref_total
|
||||
pois.append(face_centroid)
|
||||
|
||||
average_point = poi_average(pois, settings)
|
||||
|
||||
if settings.annotate_image:
|
||||
d = ImageDraw.Draw(im)
|
||||
max_size = min(im.width, im.height) * 0.07
|
||||
if corner_centroid is not None:
|
||||
color = BLUE
|
||||
box = corner_centroid.bounding(max_size * corner_centroid.weight)
|
||||
d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color)
|
||||
d.ellipse(box, outline=color)
|
||||
if len(corner_points) > 1:
|
||||
for f in corner_points:
|
||||
d.rectangle(f.bounding(4), outline=color)
|
||||
if entropy_centroid is not None:
|
||||
color = "#ff0"
|
||||
box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
|
||||
d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
|
||||
d.ellipse(box, outline=color)
|
||||
if len(entropy_points) > 1:
|
||||
for f in entropy_points:
|
||||
d.rectangle(f.bounding(4), outline=color)
|
||||
if face_centroid is not None:
|
||||
color = RED
|
||||
box = face_centroid.bounding(max_size * face_centroid.weight)
|
||||
d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color)
|
||||
d.ellipse(box, outline=color)
|
||||
if len(face_points) > 1:
|
||||
for f in face_points:
|
||||
d.rectangle(f.bounding(4), outline=color)
|
||||
|
||||
d.ellipse(average_point.bounding(max_size), outline=GREEN)
|
||||
|
||||
return average_point
|
||||
|
||||
|
||||
def image_face_points(im, settings):
|
||||
if settings.dnn_model_path is not None:
|
||||
detector = cv2.FaceDetectorYN.create(
|
||||
settings.dnn_model_path,
|
||||
"",
|
||||
(im.width, im.height),
|
||||
0.9, # score threshold
|
||||
0.3, # nms threshold
|
||||
5000 # keep top k before nms
|
||||
)
|
||||
faces = detector.detect(np.array(im))
|
||||
results = []
|
||||
if faces[1] is not None:
|
||||
for face in faces[1]:
|
||||
x = face[0]
|
||||
y = face[1]
|
||||
w = face[2]
|
||||
h = face[3]
|
||||
results.append(
|
||||
PointOfInterest(
|
||||
int(x + (w * 0.5)), # face focus left/right is center
|
||||
int(y + (h * 0.33)), # face focus up/down is close to the top of the head
|
||||
size = w,
|
||||
weight = 1/len(faces[1])
|
||||
)
|
||||
)
|
||||
return results
|
||||
else:
|
||||
np_im = np.array(im)
|
||||
gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
tries = [
|
||||
[ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
|
||||
[ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
|
||||
[ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
|
||||
[ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
|
||||
[ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
|
||||
[ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
|
||||
[ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
|
||||
[ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
|
||||
]
|
||||
for t in tries:
|
||||
classifier = cv2.CascadeClassifier(t[0])
|
||||
minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
|
||||
try:
|
||||
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
|
||||
minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
|
||||
except:
|
||||
continue
|
||||
|
||||
if len(faces) > 0:
|
||||
rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
|
||||
return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
|
||||
return []
|
||||
|
||||
|
||||
def image_corner_points(im, settings):
|
||||
grayscale = im.convert("L")
|
||||
|
||||
# naive attempt at preventing focal points from collecting at watermarks near the bottom
|
||||
gd = ImageDraw.Draw(grayscale)
|
||||
gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
|
||||
|
||||
np_im = np.array(grayscale)
|
||||
|
||||
points = cv2.goodFeaturesToTrack(
|
||||
np_im,
|
||||
maxCorners=100,
|
||||
qualityLevel=0.04,
|
||||
minDistance=min(grayscale.width, grayscale.height)*0.06,
|
||||
useHarrisDetector=False,
|
||||
)
|
||||
|
||||
if points is None:
|
||||
return []
|
||||
|
||||
focal_points = []
|
||||
for point in points:
|
||||
x, y = point.ravel()
|
||||
focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
|
||||
|
||||
return focal_points
|
||||
|
||||
|
||||
def image_entropy_points(im, settings):
|
||||
landscape = im.height < im.width
|
||||
portrait = im.height > im.width
|
||||
if landscape:
|
||||
move_idx = [0, 2]
|
||||
move_max = im.size[0]
|
||||
elif portrait:
|
||||
move_idx = [1, 3]
|
||||
move_max = im.size[1]
|
||||
else:
|
||||
return []
|
||||
|
||||
e_max = 0
|
||||
crop_current = [0, 0, settings.crop_width, settings.crop_height]
|
||||
crop_best = crop_current
|
||||
while crop_current[move_idx[1]] < move_max:
|
||||
crop = im.crop(tuple(crop_current))
|
||||
e = image_entropy(crop)
|
||||
|
||||
if (e > e_max):
|
||||
e_max = e
|
||||
crop_best = list(crop_current)
|
||||
|
||||
crop_current[move_idx[0]] += 4
|
||||
crop_current[move_idx[1]] += 4
|
||||
|
||||
x_mid = int(crop_best[0] + settings.crop_width/2)
|
||||
y_mid = int(crop_best[1] + settings.crop_height/2)
|
||||
|
||||
return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
|
||||
|
||||
|
||||
def image_entropy(im):
|
||||
# greyscale image entropy
|
||||
# band = np.asarray(im.convert("L"))
|
||||
band = np.asarray(im.convert("1"), dtype=np.uint8)
|
||||
hist, _ = np.histogram(band, bins=range(0, 256))
|
||||
hist = hist[hist > 0]
|
||||
return -np.log2(hist / hist.sum()).sum()
|
||||
|
||||
def centroid(pois):
|
||||
x = [poi.x for poi in pois]
|
||||
y = [poi.y for poi in pois]
|
||||
return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
|
||||
|
||||
|
||||
def poi_average(pois, settings):
|
||||
weight = 0.0
|
||||
x = 0.0
|
||||
y = 0.0
|
||||
for poi in pois:
|
||||
weight += poi.weight
|
||||
x += poi.x * poi.weight
|
||||
y += poi.y * poi.weight
|
||||
avg_x = round(x / weight)
|
||||
avg_y = round(y / weight)
|
||||
|
||||
return PointOfInterest(avg_x, avg_y)
|
||||
|
||||
|
||||
def is_landscape(w, h):
|
||||
return w > h
|
||||
|
||||
|
||||
def is_portrait(w, h):
|
||||
return h > w
|
||||
|
||||
|
||||
def is_square(w, h):
|
||||
return w == h
|
||||
|
||||
|
||||
def download_and_cache_models(dirname):
|
||||
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
|
||||
model_file_name = 'face_detection_yunet.onnx'
|
||||
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
cache_file = os.path.join(dirname, model_file_name)
|
||||
if not os.path.exists(cache_file):
|
||||
print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
|
||||
response = requests.get(download_url)
|
||||
with open(cache_file, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
if os.path.exists(cache_file):
|
||||
return cache_file
|
||||
return None
|
||||
|
||||
|
||||
class PointOfInterest:
|
||||
def __init__(self, x, y, weight=1.0, size=10):
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.weight = weight
|
||||
self.size = size
|
||||
|
||||
def bounding(self, size):
|
||||
return [
|
||||
self.x - size//2,
|
||||
self.y - size//2,
|
||||
self.x + size//2,
|
||||
self.y + size//2
|
||||
]
|
||||
|
||||
|
||||
class Settings:
|
||||
def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
|
||||
self.crop_width = crop_width
|
||||
self.crop_height = crop_height
|
||||
self.corner_points_weight = corner_points_weight
|
||||
self.entropy_points_weight = entropy_points_weight
|
||||
self.face_points_weight = face_points_weight
|
||||
self.annotate_image = annotate_image
|
||||
self.destop_view_image = False
|
||||
self.dnn_model_path = dnn_model_path
|
||||
@ -0,0 +1,172 @@
|
||||
import json
|
||||
import os.path
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import git
|
||||
|
||||
import gradio as gr
|
||||
import html
|
||||
|
||||
from modules import extensions, shared, paths
|
||||
|
||||
|
||||
def check_access():
|
||||
assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags"
|
||||
|
||||
|
||||
def apply_and_restart(disable_list, update_list):
|
||||
check_access()
|
||||
|
||||
disabled = json.loads(disable_list)
|
||||
assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
|
||||
|
||||
update = json.loads(update_list)
|
||||
assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
|
||||
|
||||
update = set(update)
|
||||
|
||||
for ext in extensions.extensions:
|
||||
if ext.name not in update:
|
||||
continue
|
||||
|
||||
try:
|
||||
ext.pull()
|
||||
except Exception:
|
||||
print(f"Error pulling updates for {ext.name}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
shared.opts.disabled_extensions = disabled
|
||||
shared.opts.save(shared.config_filename)
|
||||
|
||||
shared.state.interrupt()
|
||||
shared.state.need_restart = True
|
||||
|
||||
|
||||
def check_updates():
|
||||
check_access()
|
||||
|
||||
for ext in extensions.extensions:
|
||||
if ext.remote is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
ext.check_updates()
|
||||
except Exception:
|
||||
print(f"Error checking updates for {ext.name}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
return extension_table()
|
||||
|
||||
|
||||
def extension_table():
|
||||
code = f"""<!-- {time.time()} -->
|
||||
<table id="extensions">
|
||||
<thead>
|
||||
<tr>
|
||||
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
|
||||
<th>URL</th>
|
||||
<th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
"""
|
||||
|
||||
for ext in extensions.extensions:
|
||||
if ext.can_update:
|
||||
ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
|
||||
else:
|
||||
ext_status = ext.status
|
||||
|
||||
code += f"""
|
||||
<tr>
|
||||
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
|
||||
<td><a href="{html.escape(ext.remote or '')}">{html.escape(ext.remote or '')}</a></td>
|
||||
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
|
||||
</tr>
|
||||
"""
|
||||
|
||||
code += """
|
||||
</tbody>
|
||||
</table>
|
||||
"""
|
||||
|
||||
return code
|
||||
|
||||
|
||||
def install_extension_from_url(dirname, url):
|
||||
check_access()
|
||||
|
||||
assert url, 'No URL specified'
|
||||
|
||||
if dirname is None or dirname == "":
|
||||
*parts, last_part = url.split('/')
|
||||
last_part = last_part.replace(".git", "")
|
||||
|
||||
dirname = last_part
|
||||
|
||||
target_dir = os.path.join(extensions.extensions_dir, dirname)
|
||||
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
|
||||
|
||||
assert len([x for x in extensions.extensions if x.remote == url]) == 0, 'Extension with this URL is already installed'
|
||||
|
||||
tmpdir = os.path.join(paths.script_path, "tmp", dirname)
|
||||
|
||||
try:
|
||||
shutil.rmtree(tmpdir, True)
|
||||
|
||||
repo = git.Repo.clone_from(url, tmpdir)
|
||||
repo.remote().fetch()
|
||||
|
||||
os.rename(tmpdir, target_dir)
|
||||
|
||||
extensions.list_extensions()
|
||||
return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
|
||||
finally:
|
||||
shutil.rmtree(tmpdir, True)
|
||||
|
||||
|
||||
def create_ui():
|
||||
import modules.ui
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as ui:
|
||||
with gr.Tabs(elem_id="tabs_extensions") as tabs:
|
||||
with gr.TabItem("Installed"):
|
||||
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False)
|
||||
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False)
|
||||
|
||||
with gr.Row():
|
||||
apply = gr.Button(value="Apply and restart UI", variant="primary")
|
||||
check = gr.Button(value="Check for updates")
|
||||
|
||||
extensions_table = gr.HTML(lambda: extension_table())
|
||||
|
||||
apply.click(
|
||||
fn=apply_and_restart,
|
||||
_js="extensions_apply",
|
||||
inputs=[extensions_disabled_list, extensions_update_list],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
check.click(
|
||||
fn=check_updates,
|
||||
_js="extensions_check",
|
||||
inputs=[],
|
||||
outputs=[extensions_table],
|
||||
)
|
||||
|
||||
with gr.TabItem("Install from URL"):
|
||||
install_url = gr.Text(label="URL for extension's git repository")
|
||||
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
|
||||
intall_button = gr.Button(value="Install", variant="primary")
|
||||
intall_result = gr.HTML(elem_id="extension_install_result")
|
||||
|
||||
intall_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
|
||||
inputs=[install_dirname, install_url],
|
||||
outputs=[extensions_table, intall_result],
|
||||
)
|
||||
|
||||
return ui
|
||||
@ -0,0 +1,29 @@
|
||||
import unittest
|
||||
|
||||
|
||||
class TestExtrasWorking(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.url_img2img = "http://localhost:7860/sdapi/v1/extra-single-image"
|
||||
self.simple_extras = {
|
||||
"resize_mode": 0,
|
||||
"show_extras_results": True,
|
||||
"gfpgan_visibility": 0,
|
||||
"codeformer_visibility": 0,
|
||||
"codeformer_weight": 0,
|
||||
"upscaling_resize": 2,
|
||||
"upscaling_resize_w": 512,
|
||||
"upscaling_resize_h": 512,
|
||||
"upscaling_crop": True,
|
||||
"upscaler_1": "None",
|
||||
"upscaler_2": "None",
|
||||
"extras_upscaler_2_visibility": 0,
|
||||
"image": ""
|
||||
}
|
||||
|
||||
|
||||
class TestExtrasCorrectness(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -0,0 +1,59 @@
|
||||
import unittest
|
||||
import requests
|
||||
from gradio.processing_utils import encode_pil_to_base64
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class TestImg2ImgWorking(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.url_img2img = "http://localhost:7860/sdapi/v1/img2img"
|
||||
self.simple_img2img = {
|
||||
"init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))],
|
||||
"resize_mode": 0,
|
||||
"denoising_strength": 0.75,
|
||||
"mask": None,
|
||||
"mask_blur": 4,
|
||||
"inpainting_fill": 0,
|
||||
"inpaint_full_res": False,
|
||||
"inpaint_full_res_padding": 0,
|
||||
"inpainting_mask_invert": 0,
|
||||
"prompt": "example prompt",
|
||||
"styles": [],
|
||||
"seed": -1,
|
||||
"subseed": -1,
|
||||
"subseed_strength": 0,
|
||||
"seed_resize_from_h": -1,
|
||||
"seed_resize_from_w": -1,
|
||||
"batch_size": 1,
|
||||
"n_iter": 1,
|
||||
"steps": 3,
|
||||
"cfg_scale": 7,
|
||||
"width": 64,
|
||||
"height": 64,
|
||||
"restore_faces": False,
|
||||
"tiling": False,
|
||||
"negative_prompt": "",
|
||||
"eta": 0,
|
||||
"s_churn": 0,
|
||||
"s_tmax": 0,
|
||||
"s_tmin": 0,
|
||||
"s_noise": 1,
|
||||
"override_settings": {},
|
||||
"sampler_index": "Euler a",
|
||||
"include_init_images": False
|
||||
}
|
||||
|
||||
def test_img2img_simple_performed(self):
|
||||
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
|
||||
|
||||
def test_inpainting_masked_performed(self):
|
||||
self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
|
||||
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
|
||||
|
||||
|
||||
class TestImg2ImgCorrectness(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -0,0 +1,19 @@
|
||||
import unittest
|
||||
import requests
|
||||
import time
|
||||
|
||||
|
||||
def run_tests():
|
||||
timeout_threshold = 240
|
||||
start_time = time.time()
|
||||
while time.time()-start_time < timeout_threshold:
|
||||
try:
|
||||
requests.head("http://localhost:7860/")
|
||||
break
|
||||
except requests.exceptions.ConnectionError:
|
||||
pass
|
||||
if time.time()-start_time < timeout_threshold:
|
||||
suite = unittest.TestLoader().discover('', pattern='*_test.py')
|
||||
result = unittest.TextTestRunner(verbosity=2).run(suite)
|
||||
else:
|
||||
print("Launch unsuccessful")
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 9.7 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 362 B |
@ -0,0 +1,74 @@
|
||||
import unittest
|
||||
import requests
|
||||
|
||||
|
||||
class TestTxt2ImgWorking(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img"
|
||||
self.simple_txt2img = {
|
||||
"enable_hr": False,
|
||||
"denoising_strength": 0,
|
||||
"firstphase_width": 0,
|
||||
"firstphase_height": 0,
|
||||
"prompt": "example prompt",
|
||||
"styles": [],
|
||||
"seed": -1,
|
||||
"subseed": -1,
|
||||
"subseed_strength": 0,
|
||||
"seed_resize_from_h": -1,
|
||||
"seed_resize_from_w": -1,
|
||||
"batch_size": 1,
|
||||
"n_iter": 1,
|
||||
"steps": 3,
|
||||
"cfg_scale": 7,
|
||||
"width": 64,
|
||||
"height": 64,
|
||||
"restore_faces": False,
|
||||
"tiling": False,
|
||||
"negative_prompt": "",
|
||||
"eta": 0,
|
||||
"s_churn": 0,
|
||||
"s_tmax": 0,
|
||||
"s_tmin": 0,
|
||||
"s_noise": 1,
|
||||
"sampler_index": "Euler a"
|
||||
}
|
||||
|
||||
def test_txt2img_simple_performed(self):
|
||||
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
|
||||
|
||||
def test_txt2img_with_negative_prompt_performed(self):
|
||||
self.simple_txt2img["negative_prompt"] = "example negative prompt"
|
||||
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
|
||||
|
||||
def test_txt2img_not_square_image_performed(self):
|
||||
self.simple_txt2img["height"] = 128
|
||||
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
|
||||
|
||||
def test_txt2img_with_hrfix_performed(self):
|
||||
self.simple_txt2img["enable_hr"] = True
|
||||
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
|
||||
|
||||
def test_txt2img_with_restore_faces_performed(self):
|
||||
self.simple_txt2img["restore_faces"] = True
|
||||
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
|
||||
|
||||
def test_txt2img_with_tiling_faces_performed(self):
|
||||
self.simple_txt2img["tiling"] = True
|
||||
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
|
||||
|
||||
def test_txt2img_with_vanilla_sampler_performed(self):
|
||||
self.simple_txt2img["sampler_index"] = "PLMS"
|
||||
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
|
||||
|
||||
def test_txt2img_multiple_batches_performed(self):
|
||||
self.simple_txt2img["n_iter"] = 2
|
||||
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
|
||||
|
||||
|
||||
class TestTxt2ImgCorrectness(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue