Merge branch 'master' into tensorboard
commit
9cd7716753
@ -0,0 +1,5 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: WebUI Community Support
|
||||
url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
|
||||
about: Please ask and answer questions here.
|
||||
@ -0,0 +1,29 @@
|
||||
name: Run basic features tests on CPU with empty SD model
|
||||
|
||||
on:
|
||||
- push
|
||||
- pull_request
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v3
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.10.6
|
||||
cache: pip
|
||||
cache-dependency-path: |
|
||||
**/requirements*txt
|
||||
- name: Run tests
|
||||
run: python launch.py --tests --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
|
||||
- name: Upload main app stdout-stderr
|
||||
uses: actions/upload-artifact@v3
|
||||
if: always()
|
||||
with:
|
||||
name: stdout-stderr
|
||||
path: |
|
||||
test/stdout.txt
|
||||
test/stderr.txt
|
||||
@ -1 +1,12 @@
|
||||
* @AUTOMATIC1111
|
||||
|
||||
# if you were managing a localization and were removed from this file, this is because
|
||||
# the intended way to do localizations now is via extensions. See:
|
||||
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions
|
||||
# Make a repo with your localization and since you are still listed as a collaborator
|
||||
# you can add it to the wiki page yourself. This change is because some people complained
|
||||
# the git commit log is cluttered with things unrelated to almost everyone and
|
||||
# because I believe this is the best overall for the project to handle localizations almost
|
||||
# entirely without my oversight.
|
||||
|
||||
|
||||
|
||||
@ -0,0 +1,72 @@
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: modules.xlmr.BertSeriesModelWithTransformation
|
||||
params:
|
||||
name: "XLMR-Large"
|
||||
@ -0,0 +1,70 @@
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
@ -0,0 +1,6 @@
|
||||
import os
|
||||
from modules import paths
|
||||
|
||||
|
||||
def preload(parser):
|
||||
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
|
||||
@ -0,0 +1,286 @@
|
||||
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
|
||||
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
|
||||
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
|
||||
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
import torch.nn.functional as F
|
||||
from contextlib import contextmanager
|
||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
import ldm.models.autoencoder
|
||||
|
||||
class VQModel(pl.LightningModule):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
batch_resize_range=None,
|
||||
scheduler_config=None,
|
||||
lr_g_factor=1.0,
|
||||
remap=None,
|
||||
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||
use_ema=False
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.n_embed = n_embed
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
||||
remap=remap,
|
||||
sane_index_shape=sane_index_shape)
|
||||
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels)==int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
self.batch_resize_range = batch_resize_range
|
||||
if self.batch_resize_range is not None:
|
||||
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
|
||||
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self)
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
self.scheduler_config = scheduler_config
|
||||
self.lr_g_factor = lr_g_factor
|
||||
|
||||
@contextmanager
|
||||
def ema_scope(self, context=None):
|
||||
if self.use_ema:
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
print(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||
if len(missing) > 0:
|
||||
print(f"Missing Keys: {missing}")
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
return quant, emb_loss, info
|
||||
|
||||
def encode_to_prequant(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, quant):
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
def decode_code(self, code_b):
|
||||
quant_b = self.quantize.embed_code(code_b)
|
||||
dec = self.decode(quant_b)
|
||||
return dec
|
||||
|
||||
def forward(self, input, return_pred_indices=False):
|
||||
quant, diff, (_,_,ind) = self.encode(input)
|
||||
dec = self.decode(quant)
|
||||
if return_pred_indices:
|
||||
return dec, diff, ind
|
||||
return dec, diff
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
if self.batch_resize_range is not None:
|
||||
lower_size = self.batch_resize_range[0]
|
||||
upper_size = self.batch_resize_range[1]
|
||||
if self.global_step <= 4:
|
||||
# do the first few batches with max size to avoid later oom
|
||||
new_resize = upper_size
|
||||
else:
|
||||
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
|
||||
if new_resize != x.shape[2]:
|
||||
x = F.interpolate(x, size=new_resize, mode="bicubic")
|
||||
x = x.detach()
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
# https://github.com/pytorch/pytorch/issues/37142
|
||||
# try not to fool the heuristics
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# autoencode
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train",
|
||||
predicted_indices=ind)
|
||||
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# discriminator
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
log_dict = self._validation_step(batch, batch_idx)
|
||||
with self.ema_scope():
|
||||
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
||||
return log_dict
|
||||
|
||||
def _validation_step(self, batch, batch_idx, suffix=""):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val"+suffix,
|
||||
predicted_indices=ind
|
||||
)
|
||||
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val"+suffix,
|
||||
predicted_indices=ind
|
||||
)
|
||||
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
self.log(f"val{suffix}/rec_loss", rec_loss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
self.log(f"val{suffix}/aeloss", aeloss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
||||
del log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr_d = self.learning_rate
|
||||
lr_g = self.lr_g_factor*self.learning_rate
|
||||
print("lr_d", lr_d)
|
||||
print("lr_g", lr_g)
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quantize.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr_g, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr_d, betas=(0.5, 0.9))
|
||||
|
||||
if self.scheduler_config is not None:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
|
||||
print("Setting up LambdaLR scheduler...")
|
||||
scheduler = [
|
||||
{
|
||||
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
},
|
||||
{
|
||||
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
},
|
||||
]
|
||||
return [opt_ae, opt_disc], scheduler
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if only_inputs:
|
||||
log["inputs"] = x
|
||||
return log
|
||||
xrec, _ = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["inputs"] = x
|
||||
log["reconstructions"] = xrec
|
||||
if plot_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, _ = self(x)
|
||||
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
||||
log["reconstructions_ema"] = xrec_ema
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class VQModelInterface(VQModel):
|
||||
def __init__(self, embed_dim, *args, **kwargs):
|
||||
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, h, force_not_quantize=False):
|
||||
# also go through quantization layer
|
||||
if not force_not_quantize:
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
else:
|
||||
quant = h
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
setattr(ldm.models.autoencoder, "VQModel", VQModel)
|
||||
setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,6 @@
|
||||
import os
|
||||
from modules import paths
|
||||
|
||||
|
||||
def preload(parser):
|
||||
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
|
||||
@ -0,0 +1,6 @@
|
||||
import os
|
||||
from modules import paths
|
||||
|
||||
|
||||
def preload(parser):
|
||||
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
|
||||
@ -0,0 +1,107 @@
|
||||
// Stable Diffusion WebUI - Bracket checker
|
||||
// Version 1.0
|
||||
// By Hingashi no Florin/Bwin4L
|
||||
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
|
||||
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
|
||||
|
||||
function checkBrackets(evt) {
|
||||
textArea = evt.target;
|
||||
tabName = evt.target.parentElement.parentElement.id.split("_")[0];
|
||||
counterElt = document.querySelector('gradio-app').shadowRoot.querySelector('#' + tabName + '_token_counter');
|
||||
|
||||
promptName = evt.target.parentElement.parentElement.id.includes('neg') ? ' negative' : '';
|
||||
|
||||
errorStringParen = '(' + tabName + promptName + ' prompt) - Different number of opening and closing parentheses detected.\n';
|
||||
errorStringSquare = '[' + tabName + promptName + ' prompt] - Different number of opening and closing square brackets detected.\n';
|
||||
errorStringCurly = '{' + tabName + promptName + ' prompt} - Different number of opening and closing curly brackets detected.\n';
|
||||
|
||||
openBracketRegExp = /\(/g;
|
||||
closeBracketRegExp = /\)/g;
|
||||
|
||||
openSquareBracketRegExp = /\[/g;
|
||||
closeSquareBracketRegExp = /\]/g;
|
||||
|
||||
openCurlyBracketRegExp = /\{/g;
|
||||
closeCurlyBracketRegExp = /\}/g;
|
||||
|
||||
totalOpenBracketMatches = 0;
|
||||
totalCloseBracketMatches = 0;
|
||||
totalOpenSquareBracketMatches = 0;
|
||||
totalCloseSquareBracketMatches = 0;
|
||||
totalOpenCurlyBracketMatches = 0;
|
||||
totalCloseCurlyBracketMatches = 0;
|
||||
|
||||
openBracketMatches = textArea.value.match(openBracketRegExp);
|
||||
if(openBracketMatches) {
|
||||
totalOpenBracketMatches = openBracketMatches.length;
|
||||
}
|
||||
|
||||
closeBracketMatches = textArea.value.match(closeBracketRegExp);
|
||||
if(closeBracketMatches) {
|
||||
totalCloseBracketMatches = closeBracketMatches.length;
|
||||
}
|
||||
|
||||
openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
|
||||
if(openSquareBracketMatches) {
|
||||
totalOpenSquareBracketMatches = openSquareBracketMatches.length;
|
||||
}
|
||||
|
||||
closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
|
||||
if(closeSquareBracketMatches) {
|
||||
totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
|
||||
}
|
||||
|
||||
openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
|
||||
if(openCurlyBracketMatches) {
|
||||
totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
|
||||
}
|
||||
|
||||
closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
|
||||
if(closeCurlyBracketMatches) {
|
||||
totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
|
||||
}
|
||||
|
||||
if(totalOpenBracketMatches != totalCloseBracketMatches) {
|
||||
if(!counterElt.title.includes(errorStringParen)) {
|
||||
counterElt.title += errorStringParen;
|
||||
}
|
||||
} else {
|
||||
counterElt.title = counterElt.title.replace(errorStringParen, '');
|
||||
}
|
||||
|
||||
if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
|
||||
if(!counterElt.title.includes(errorStringSquare)) {
|
||||
counterElt.title += errorStringSquare;
|
||||
}
|
||||
} else {
|
||||
counterElt.title = counterElt.title.replace(errorStringSquare, '');
|
||||
}
|
||||
|
||||
if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
|
||||
if(!counterElt.title.includes(errorStringCurly)) {
|
||||
counterElt.title += errorStringCurly;
|
||||
}
|
||||
} else {
|
||||
counterElt.title = counterElt.title.replace(errorStringCurly, '');
|
||||
}
|
||||
|
||||
if(counterElt.title != '') {
|
||||
counterElt.style = 'color: #FF5555;';
|
||||
} else {
|
||||
counterElt.style = '';
|
||||
}
|
||||
}
|
||||
|
||||
var shadowRootLoaded = setInterval(function() {
|
||||
var shadowTextArea = document.querySelector('gradio-app').shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
|
||||
if(shadowTextArea.length < 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
clearInterval(shadowRootLoaded);
|
||||
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_prompt').onkeyup = checkBrackets;
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_neg_prompt').onkeyup = checkBrackets;
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_prompt').onkeyup = checkBrackets;
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_neg_prompt').onkeyup = checkBrackets;
|
||||
}, 1000);
|
||||
@ -0,0 +1,50 @@
|
||||
import random
|
||||
|
||||
from modules import script_callbacks, shared
|
||||
import gradio as gr
|
||||
|
||||
art_symbol = '\U0001f3a8' # 🎨
|
||||
global_prompt = None
|
||||
related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
|
||||
|
||||
|
||||
def roll_artist(prompt):
|
||||
allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
|
||||
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
|
||||
|
||||
return prompt + ", " + artist.name if prompt != '' else artist.name
|
||||
|
||||
|
||||
def add_roll_button(prompt):
|
||||
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
||||
|
||||
roll.click(
|
||||
fn=roll_artist,
|
||||
_js="update_txt2img_tokens",
|
||||
inputs=[
|
||||
prompt,
|
||||
],
|
||||
outputs=[
|
||||
prompt,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def after_component(component, **kwargs):
|
||||
global global_prompt
|
||||
|
||||
elem_id = kwargs.get('elem_id', None)
|
||||
if elem_id not in related_ids:
|
||||
return
|
||||
|
||||
if elem_id == "txt2img_prompt":
|
||||
global_prompt = component
|
||||
elif elem_id == "txt2img_clear_prompt":
|
||||
add_roll_button(global_prompt)
|
||||
elif elem_id == "img2img_prompt":
|
||||
global_prompt = component
|
||||
elif elem_id == "img2img_clear_prompt":
|
||||
add_roll_button(global_prompt)
|
||||
|
||||
|
||||
script_callbacks.on_after_component(after_component)
|
||||
@ -0,0 +1,13 @@
|
||||
<div>
|
||||
<a href="/docs">API</a>
|
||||
•
|
||||
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
|
||||
•
|
||||
<a href="https://gradio.app">Gradio</a>
|
||||
•
|
||||
<a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
|
||||
</div>
|
||||
<br />
|
||||
<div class="versions">
|
||||
{versions}
|
||||
</div>
|
||||
@ -0,0 +1,419 @@
|
||||
<style>
|
||||
#licenses h2 {font-size: 1.2em; font-weight: bold; margin-bottom: 0.2em;}
|
||||
#licenses small {font-size: 0.95em; opacity: 0.85;}
|
||||
#licenses pre { margin: 1em 0 2em 0;}
|
||||
</style>
|
||||
|
||||
<h2><a href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">CodeFormer</a></h2>
|
||||
<small>Parts of CodeFormer code had to be copied to be compatible with GFPGAN.</small>
|
||||
<pre>
|
||||
S-Lab License 1.0
|
||||
|
||||
Copyright 2022 S-Lab
|
||||
|
||||
Redistribution and use for non-commercial purpose in source and
|
||||
binary forms, with or without modification, are permitted provided
|
||||
that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
In the event that redistribution and/or use for commercial purpose in
|
||||
source or binary forms, with or without modification is required,
|
||||
please contact the contributor(s) of the work.
|
||||
</pre>
|
||||
|
||||
|
||||
<h2><a href="https://github.com/victorca25/iNNfer/blob/main/LICENSE">ESRGAN</a></h2>
|
||||
<small>Code for architecture and reading models copied.</small>
|
||||
<pre>
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2021 victorca25
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE">Real-ESRGAN</a></h2>
|
||||
<small>Some code is copied to support ESRGAN models.</small>
|
||||
<pre>
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2021, Xintao Wang
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE">InvokeAI</a></h2>
|
||||
<small>Some code for compatibility with OSX is taken from lstein's repository.</small>
|
||||
<pre>
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 InvokeAI Team
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/Hafiidz/latent-diffusion/blob/main/LICENSE">LDSR</a></h2>
|
||||
<small>Code added by contirubtors, most likely copied from this repository.</small>
|
||||
<pre>
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/pharmapsychotic/clip-interrogator/blob/main/LICENSE">CLIP Interrogator</a></h2>
|
||||
<small>Some small amounts of code borrowed and reworked.</small>
|
||||
<pre>
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 pharmapsychotic
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
|
||||
<small>Code added by contributors, most likely copied from this repository.</small>
|
||||
|
||||
<pre>
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [2021] [SwinIR Authors]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE">Memory Efficient Attention</a></h2>
|
||||
<small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small>
|
||||
<pre>
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Alex Birch
|
||||
Copyright (c) 2023 Amin Rezaei
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
</pre>
|
||||
|
||||
@ -0,0 +1,35 @@
|
||||
|
||||
function extensions_apply(_, _){
|
||||
disable = []
|
||||
update = []
|
||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
||||
if(x.name.startsWith("enable_") && ! x.checked)
|
||||
disable.push(x.name.substr(7))
|
||||
|
||||
if(x.name.startsWith("update_") && x.checked)
|
||||
update.push(x.name.substr(7))
|
||||
})
|
||||
|
||||
restart_reload()
|
||||
|
||||
return [JSON.stringify(disable), JSON.stringify(update)]
|
||||
}
|
||||
|
||||
function extensions_check(){
|
||||
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
||||
x.innerHTML = "Loading..."
|
||||
})
|
||||
|
||||
return []
|
||||
}
|
||||
|
||||
function install_extension_from_index(button, url){
|
||||
button.disabled = "disabled"
|
||||
button.value = "Installing..."
|
||||
|
||||
textarea = gradioApp().querySelector('#extension_to_install textarea')
|
||||
textarea.value = url
|
||||
textarea.dispatchEvent(new Event("input", { bubbles: true }))
|
||||
|
||||
gradioApp().querySelector('#install_extension_button').click()
|
||||
}
|
||||
@ -0,0 +1,33 @@
|
||||
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
|
||||
|
||||
let txt2img_gallery, img2img_gallery, modal = undefined;
|
||||
onUiUpdate(function(){
|
||||
if (!txt2img_gallery) {
|
||||
txt2img_gallery = attachGalleryListeners("txt2img")
|
||||
}
|
||||
if (!img2img_gallery) {
|
||||
img2img_gallery = attachGalleryListeners("img2img")
|
||||
}
|
||||
if (!modal) {
|
||||
modal = gradioApp().getElementById('lightboxModal')
|
||||
modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
|
||||
}
|
||||
});
|
||||
|
||||
let modalObserver = new MutationObserver(function(mutations) {
|
||||
mutations.forEach(function(mutationRecord) {
|
||||
let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
|
||||
if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
|
||||
gradioApp().getElementById(selectedTab+"_generation_info_button").click()
|
||||
});
|
||||
});
|
||||
|
||||
function attachGalleryListeners(tab_name) {
|
||||
gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
|
||||
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
|
||||
gallery?.addEventListener('keydown', (e) => {
|
||||
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
|
||||
gradioApp().getElementById(tab_name+"_generation_info_button").click()
|
||||
});
|
||||
return gallery;
|
||||
}
|
||||
@ -0,0 +1,25 @@
|
||||
|
||||
function setInactive(elem, inactive){
|
||||
console.log(elem)
|
||||
if(inactive){
|
||||
elem.classList.add('inactive')
|
||||
} else{
|
||||
elem.classList.remove('inactive')
|
||||
}
|
||||
}
|
||||
|
||||
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
|
||||
console.log(enable, width, height, hr_scale, hr_resize_x, hr_resize_y)
|
||||
|
||||
hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
|
||||
hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
|
||||
hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
|
||||
|
||||
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
|
||||
|
||||
setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0)
|
||||
setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0)
|
||||
setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0)
|
||||
|
||||
return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]
|
||||
}
|
||||
@ -1,206 +0,0 @@
|
||||
var images_history_click_image = function(){
|
||||
if (!this.classList.contains("transform")){
|
||||
var gallery = images_history_get_parent_by_class(this, "images_history_cantainor");
|
||||
var buttons = gallery.querySelectorAll(".gallery-item");
|
||||
var i = 0;
|
||||
var hidden_list = [];
|
||||
buttons.forEach(function(e){
|
||||
if (e.style.display == "none"){
|
||||
hidden_list.push(i);
|
||||
}
|
||||
i += 1;
|
||||
})
|
||||
if (hidden_list.length > 0){
|
||||
setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
|
||||
}
|
||||
}
|
||||
images_history_set_image_info(this);
|
||||
}
|
||||
|
||||
var images_history_click_tab = function(){
|
||||
var tabs_box = gradioApp().getElementById("images_history_tab");
|
||||
if (!tabs_box.classList.contains(this.getAttribute("tabname"))) {
|
||||
gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_renew_page").click();
|
||||
tabs_box.classList.add(this.getAttribute("tabname"))
|
||||
}
|
||||
}
|
||||
|
||||
function images_history_disabled_del(){
|
||||
gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
|
||||
btn.setAttribute('disabled','disabled');
|
||||
});
|
||||
}
|
||||
|
||||
function images_history_get_parent_by_class(item, class_name){
|
||||
var parent = item.parentElement;
|
||||
while(!parent.classList.contains(class_name)){
|
||||
parent = parent.parentElement;
|
||||
}
|
||||
return parent;
|
||||
}
|
||||
|
||||
function images_history_get_parent_by_tagname(item, tagname){
|
||||
var parent = item.parentElement;
|
||||
tagname = tagname.toUpperCase()
|
||||
while(parent.tagName != tagname){
|
||||
console.log(parent.tagName, tagname)
|
||||
parent = parent.parentElement;
|
||||
}
|
||||
return parent;
|
||||
}
|
||||
|
||||
function images_history_hide_buttons(hidden_list, gallery){
|
||||
var buttons = gallery.querySelectorAll(".gallery-item");
|
||||
var num = 0;
|
||||
buttons.forEach(function(e){
|
||||
if (e.style.display == "none"){
|
||||
num += 1;
|
||||
}
|
||||
});
|
||||
if (num == hidden_list.length){
|
||||
setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
|
||||
}
|
||||
for( i in hidden_list){
|
||||
buttons[hidden_list[i]].style.display = "none";
|
||||
}
|
||||
}
|
||||
|
||||
function images_history_set_image_info(button){
|
||||
var buttons = images_history_get_parent_by_tagname(button, "DIV").querySelectorAll(".gallery-item");
|
||||
var index = -1;
|
||||
var i = 0;
|
||||
buttons.forEach(function(e){
|
||||
if(e == button){
|
||||
index = i;
|
||||
}
|
||||
if(e.style.display != "none"){
|
||||
i += 1;
|
||||
}
|
||||
});
|
||||
var gallery = images_history_get_parent_by_class(button, "images_history_cantainor");
|
||||
var set_btn = gallery.querySelector(".images_history_set_index");
|
||||
var curr_idx = set_btn.getAttribute("img_index", index);
|
||||
if (curr_idx != index) {
|
||||
set_btn.setAttribute("img_index", index);
|
||||
images_history_disabled_del();
|
||||
}
|
||||
set_btn.click();
|
||||
|
||||
}
|
||||
|
||||
function images_history_get_current_img(tabname, image_path, files){
|
||||
return [
|
||||
gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"),
|
||||
image_path,
|
||||
files
|
||||
];
|
||||
}
|
||||
|
||||
function images_history_delete(del_num, tabname, img_path, img_file_name, page_index, filenames, image_index){
|
||||
image_index = parseInt(image_index);
|
||||
var tab = gradioApp().getElementById(tabname + '_images_history');
|
||||
var set_btn = tab.querySelector(".images_history_set_index");
|
||||
var buttons = [];
|
||||
tab.querySelectorAll(".gallery-item").forEach(function(e){
|
||||
if (e.style.display != 'none'){
|
||||
buttons.push(e);
|
||||
}
|
||||
});
|
||||
var img_num = buttons.length / 2;
|
||||
if (img_num <= del_num){
|
||||
setTimeout(function(tabname){
|
||||
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
|
||||
}, 30, tabname);
|
||||
} else {
|
||||
var next_img
|
||||
for (var i = 0; i < del_num; i++){
|
||||
if (image_index + i < image_index + img_num){
|
||||
buttons[image_index + i].style.display = 'none';
|
||||
buttons[image_index + img_num + 1].style.display = 'none';
|
||||
next_img = image_index + i + 1
|
||||
}
|
||||
}
|
||||
var bnt;
|
||||
if (next_img >= img_num){
|
||||
btn = buttons[image_index - del_num];
|
||||
} else {
|
||||
btn = buttons[next_img];
|
||||
}
|
||||
setTimeout(function(btn){btn.click()}, 30, btn);
|
||||
}
|
||||
images_history_disabled_del();
|
||||
return [del_num, tabname, img_path, img_file_name, page_index, filenames, image_index];
|
||||
}
|
||||
|
||||
function images_history_turnpage(img_path, page_index, image_index, tabname){
|
||||
var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item");
|
||||
buttons.forEach(function(elem) {
|
||||
elem.style.display = 'block';
|
||||
})
|
||||
return [img_path, page_index, image_index, tabname];
|
||||
}
|
||||
|
||||
function images_history_enable_del_buttons(){
|
||||
gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
|
||||
btn.removeAttribute('disabled');
|
||||
})
|
||||
}
|
||||
|
||||
function images_history_init(){
|
||||
var load_txt2img_button = gradioApp().getElementById('txt2img_images_history_renew_page')
|
||||
if (load_txt2img_button){
|
||||
for (var i in images_history_tab_list ){
|
||||
tab = images_history_tab_list[i];
|
||||
gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor");
|
||||
gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index");
|
||||
gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button");
|
||||
gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery");
|
||||
|
||||
}
|
||||
var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div");
|
||||
tabs_box.setAttribute("id", "images_history_tab");
|
||||
var tab_btns = tabs_box.querySelectorAll("button");
|
||||
for (var i in images_history_tab_list){
|
||||
var tabname = images_history_tab_list[i]
|
||||
tab_btns[i].setAttribute("tabname", tabname);
|
||||
|
||||
// this refreshes history upon tab switch
|
||||
// until the history is known to work well, which is not the case now, we do not do this at startup
|
||||
//tab_btns[i].addEventListener('click', images_history_click_tab);
|
||||
}
|
||||
tabs_box.classList.add(images_history_tab_list[0]);
|
||||
|
||||
// same as above, at page load
|
||||
//load_txt2img_button.click();
|
||||
} else {
|
||||
setTimeout(images_history_init, 500);
|
||||
}
|
||||
}
|
||||
|
||||
var images_history_tab_list = ["txt2img", "img2img", "extras"];
|
||||
setTimeout(images_history_init, 500);
|
||||
document.addEventListener("DOMContentLoaded", function() {
|
||||
var mutationObserver = new MutationObserver(function(m){
|
||||
for (var i in images_history_tab_list ){
|
||||
let tabname = images_history_tab_list[i]
|
||||
var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item');
|
||||
buttons.forEach(function(bnt){
|
||||
bnt.addEventListener('click', images_history_click_image, true);
|
||||
});
|
||||
|
||||
// same as load_txt2img_button.click() above
|
||||
/*
|
||||
var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
|
||||
if (cls_btn){
|
||||
cls_btn.addEventListener('click', function(){
|
||||
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
|
||||
}, false);
|
||||
}*/
|
||||
|
||||
}
|
||||
});
|
||||
mutationObserver.observe( gradioApp(), { childList:true, subtree:true });
|
||||
|
||||
});
|
||||
|
||||
|
||||
Binary file not shown.
@ -0,0 +1,267 @@
|
||||
import inspect
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from typing import Any, Optional
|
||||
from typing_extensions import Literal
|
||||
from inflection import underscore
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||
from modules.shared import sd_upscalers, opts, parser
|
||||
from typing import Dict, List
|
||||
|
||||
API_NOT_ALLOWED = [
|
||||
"self",
|
||||
"kwargs",
|
||||
"sd_model",
|
||||
"outpath_samples",
|
||||
"outpath_grids",
|
||||
"sampler_index",
|
||||
"do_not_save_samples",
|
||||
"do_not_save_grid",
|
||||
"extra_generation_params",
|
||||
"overlay_images",
|
||||
"do_not_reload_embeddings",
|
||||
"seed_enable_extras",
|
||||
"prompt_for_display",
|
||||
"sampler_noise_scheduler_override",
|
||||
"ddim_discretize"
|
||||
]
|
||||
|
||||
class ModelDef(BaseModel):
|
||||
"""Assistance Class for Pydantic Dynamic Model Generation"""
|
||||
|
||||
field: str
|
||||
field_alias: str
|
||||
field_type: Any
|
||||
field_value: Any
|
||||
field_exclude: bool = False
|
||||
|
||||
|
||||
class PydanticModelGenerator:
|
||||
"""
|
||||
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
|
||||
source_data is a snapshot of the default values produced by the class
|
||||
params are the names of the actual keys required by __init__
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = None,
|
||||
class_instance = None,
|
||||
additional_fields = None,
|
||||
):
|
||||
def field_type_generator(k, v):
|
||||
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
||||
# print(k, v.annotation, v.default)
|
||||
field_type = v.annotation
|
||||
|
||||
return Optional[field_type]
|
||||
|
||||
def merge_class_params(class_):
|
||||
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
|
||||
parameters = {}
|
||||
for classes in all_classes:
|
||||
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||
return parameters
|
||||
|
||||
|
||||
self._model_name = model_name
|
||||
self._class_data = merge_class_params(class_instance)
|
||||
|
||||
self._model_def = [
|
||||
ModelDef(
|
||||
field=underscore(k),
|
||||
field_alias=k,
|
||||
field_type=field_type_generator(k, v),
|
||||
field_value=v.default
|
||||
)
|
||||
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||
]
|
||||
|
||||
for fields in additional_fields:
|
||||
self._model_def.append(ModelDef(
|
||||
field=underscore(fields["key"]),
|
||||
field_alias=fields["key"],
|
||||
field_type=fields["type"],
|
||||
field_value=fields["default"],
|
||||
field_exclude=fields["exclude"] if "exclude" in fields else False))
|
||||
|
||||
def generate_model(self):
|
||||
"""
|
||||
Creates a pydantic BaseModel
|
||||
from the json and overrides provided at initialization
|
||||
"""
|
||||
fields = {
|
||||
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
|
||||
}
|
||||
DynamicModel = create_model(self._model_name, **fields)
|
||||
DynamicModel.__config__.allow_population_by_field_name = True
|
||||
DynamicModel.__config__.allow_mutation = True
|
||||
return DynamicModel
|
||||
|
||||
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingTxt2Img",
|
||||
StableDiffusionProcessingTxt2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
|
||||
).generate_model()
|
||||
|
||||
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingImg2Img",
|
||||
StableDiffusionProcessingImg2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
|
||||
).generate_model()
|
||||
|
||||
class TextToImageResponse(BaseModel):
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
class ImageToImageResponse(BaseModel):
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
class ExtrasBaseRequest(BaseModel):
|
||||
resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
|
||||
show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
|
||||
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
|
||||
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
|
||||
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
|
||||
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.")
|
||||
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
|
||||
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
|
||||
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
|
||||
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
|
||||
upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
|
||||
|
||||
class ExtraBaseResponse(BaseModel):
|
||||
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
|
||||
|
||||
class ExtrasSingleImageRequest(ExtrasBaseRequest):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
|
||||
class ExtrasSingleImageResponse(ExtraBaseResponse):
|
||||
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
|
||||
class FileData(BaseModel):
|
||||
data: str = Field(title="File data", description="Base64 representation of the file")
|
||||
name: str = Field(title="File name")
|
||||
|
||||
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
|
||||
imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
||||
|
||||
class ExtrasBatchImagesResponse(ExtraBaseResponse):
|
||||
images: List[str] = Field(title="Images", description="The generated images in base64 format.")
|
||||
|
||||
class PNGInfoRequest(BaseModel):
|
||||
image: str = Field(title="Image", description="The base64 encoded PNG image")
|
||||
|
||||
class PNGInfoResponse(BaseModel):
|
||||
info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
|
||||
items: dict = Field(title="Items", description="An object containing all the info the image had")
|
||||
|
||||
class ProgressRequest(BaseModel):
|
||||
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
|
||||
|
||||
class ProgressResponse(BaseModel):
|
||||
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
|
||||
eta_relative: float = Field(title="ETA in secs")
|
||||
state: dict = Field(title="State", description="The current state snapshot")
|
||||
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||
textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
|
||||
|
||||
class InterrogateRequest(BaseModel):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
model: str = Field(default="clip", title="Model", description="The interrogate model used.")
|
||||
|
||||
class InterrogateResponse(BaseModel):
|
||||
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
|
||||
|
||||
class TrainResponse(BaseModel):
|
||||
info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
|
||||
|
||||
class CreateResponse(BaseModel):
|
||||
info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
|
||||
|
||||
class PreprocessResponse(BaseModel):
|
||||
info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
|
||||
|
||||
fields = {}
|
||||
for key, metadata in opts.data_labels.items():
|
||||
value = opts.data.get(key)
|
||||
optType = opts.typemap.get(type(metadata.default), type(value))
|
||||
|
||||
if (metadata is not None):
|
||||
fields.update({key: (Optional[optType], Field(
|
||||
default=metadata.default ,description=metadata.label))})
|
||||
else:
|
||||
fields.update({key: (Optional[optType], Field())})
|
||||
|
||||
OptionsModel = create_model("Options", **fields)
|
||||
|
||||
flags = {}
|
||||
_options = vars(parser)['_option_string_actions']
|
||||
for key in _options:
|
||||
if(_options[key].dest != 'help'):
|
||||
flag = _options[key]
|
||||
_type = str
|
||||
if _options[key].default is not None: _type = type(_options[key].default)
|
||||
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
|
||||
|
||||
FlagsModel = create_model("Flags", **flags)
|
||||
|
||||
class SamplerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
aliases: List[str] = Field(title="Aliases")
|
||||
options: Dict[str, str] = Field(title="Options")
|
||||
|
||||
class UpscalerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
model_name: Optional[str] = Field(title="Model Name")
|
||||
model_path: Optional[str] = Field(title="Path")
|
||||
model_url: Optional[str] = Field(title="URL")
|
||||
|
||||
class SDModelItem(BaseModel):
|
||||
title: str = Field(title="Title")
|
||||
model_name: str = Field(title="Model Name")
|
||||
hash: str = Field(title="Hash")
|
||||
filename: str = Field(title="Filename")
|
||||
config: str = Field(title="Config file")
|
||||
|
||||
class HypernetworkItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
path: Optional[str] = Field(title="Path")
|
||||
|
||||
class FaceRestorerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
cmd_dir: Optional[str] = Field(title="Path")
|
||||
|
||||
class RealesrganItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
path: Optional[str] = Field(title="Path")
|
||||
scale: Optional[int] = Field(title="Scale")
|
||||
|
||||
class PromptStyleItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
prompt: Optional[str] = Field(title="Prompt")
|
||||
negative_prompt: Optional[str] = Field(title="Negative Prompt")
|
||||
|
||||
class ArtistItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
score: float = Field(title="Score")
|
||||
category: str = Field(title="Category")
|
||||
|
||||
class EmbeddingItem(BaseModel):
|
||||
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
|
||||
sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
|
||||
sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
|
||||
shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
|
||||
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
|
||||
|
||||
class EmbeddingsResponse(BaseModel):
|
||||
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
|
||||
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
|
||||
|
||||
class MemoryResponse(BaseModel):
|
||||
ram: dict = Field(title="RAM", description="System memory stats")
|
||||
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
|
||||
@ -1,99 +0,0 @@
|
||||
from inflection import underscore
|
||||
from typing import Any, Dict, Optional
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img
|
||||
import inspect
|
||||
|
||||
|
||||
API_NOT_ALLOWED = [
|
||||
"self",
|
||||
"kwargs",
|
||||
"sd_model",
|
||||
"outpath_samples",
|
||||
"outpath_grids",
|
||||
"sampler_index",
|
||||
"do_not_save_samples",
|
||||
"do_not_save_grid",
|
||||
"extra_generation_params",
|
||||
"overlay_images",
|
||||
"do_not_reload_embeddings",
|
||||
"seed_enable_extras",
|
||||
"prompt_for_display",
|
||||
"sampler_noise_scheduler_override",
|
||||
"ddim_discretize"
|
||||
]
|
||||
|
||||
class ModelDef(BaseModel):
|
||||
"""Assistance Class for Pydantic Dynamic Model Generation"""
|
||||
|
||||
field: str
|
||||
field_alias: str
|
||||
field_type: Any
|
||||
field_value: Any
|
||||
|
||||
|
||||
class PydanticModelGenerator:
|
||||
"""
|
||||
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
|
||||
source_data is a snapshot of the default values produced by the class
|
||||
params are the names of the actual keys required by __init__
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = None,
|
||||
class_instance = None,
|
||||
additional_fields = None,
|
||||
):
|
||||
def field_type_generator(k, v):
|
||||
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
||||
# print(k, v.annotation, v.default)
|
||||
field_type = v.annotation
|
||||
|
||||
return Optional[field_type]
|
||||
|
||||
def merge_class_params(class_):
|
||||
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
|
||||
parameters = {}
|
||||
for classes in all_classes:
|
||||
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||
return parameters
|
||||
|
||||
|
||||
self._model_name = model_name
|
||||
self._class_data = merge_class_params(class_instance)
|
||||
self._model_def = [
|
||||
ModelDef(
|
||||
field=underscore(k),
|
||||
field_alias=k,
|
||||
field_type=field_type_generator(k, v),
|
||||
field_value=v.default
|
||||
)
|
||||
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||
]
|
||||
|
||||
for fields in additional_fields:
|
||||
self._model_def.append(ModelDef(
|
||||
field=underscore(fields["key"]),
|
||||
field_alias=fields["key"],
|
||||
field_type=fields["type"],
|
||||
field_value=fields["default"]))
|
||||
|
||||
def generate_model(self):
|
||||
"""
|
||||
Creates a pydantic BaseModel
|
||||
from the json and overrides provided at initialization
|
||||
"""
|
||||
fields = {
|
||||
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
|
||||
}
|
||||
DynamicModel = create_model(self._model_name, **fields)
|
||||
DynamicModel.__config__.allow_population_by_field_name = True
|
||||
DynamicModel.__config__.allow_mutation = True
|
||||
return DynamicModel
|
||||
|
||||
StableDiffusionProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingTxt2Img",
|
||||
StableDiffusionProcessingTxt2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}]
|
||||
).generate_model()
|
||||
@ -1,76 +0,0 @@
|
||||
import os.path
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import PIL.Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.upscaler
|
||||
from modules import devices, modelloader
|
||||
from modules.bsrgan_model_arch import RRDBNet
|
||||
|
||||
|
||||
class UpscalerBSRGAN(modules.upscaler.Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "BSRGAN"
|
||||
self.model_name = "BSRGAN 4x"
|
||||
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
|
||||
self.user_path = dirname
|
||||
super().__init__()
|
||||
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
||||
scalers = []
|
||||
if len(model_paths) == 0:
|
||||
scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
for file in model_paths:
|
||||
if "http" in file:
|
||||
name = self.model_name
|
||||
else:
|
||||
name = modelloader.friendly_name(file)
|
||||
try:
|
||||
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
except Exception:
|
||||
print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
self.scalers = scalers
|
||||
|
||||
def do_upscale(self, img: PIL.Image, selected_file):
|
||||
torch.cuda.empty_cache()
|
||||
model = self.load_model(selected_file)
|
||||
if model is None:
|
||||
return img
|
||||
model.to(devices.device_bsrgan)
|
||||
torch.cuda.empty_cache()
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(devices.device_bsrgan)
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output = 255. * np.moveaxis(output, 0, 2)
|
||||
output = output.astype(np.uint8)
|
||||
output = output[:, :, ::-1]
|
||||
torch.cuda.empty_cache()
|
||||
return PIL.Image.fromarray(output, 'RGB')
|
||||
|
||||
def load_model(self, path: str):
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
||||
progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
|
||||
return None
|
||||
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
|
||||
model.load_state_dict(torch.load(filename), strict=True)
|
||||
model.eval()
|
||||
for k, v in model.named_parameters():
|
||||
v.requires_grad = False
|
||||
return model
|
||||
|
||||
@ -1,102 +0,0 @@
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
|
||||
|
||||
def initialize_weights(net_l, scale=1):
|
||||
if not isinstance(net_l, list):
|
||||
net_l = [net_l]
|
||||
for net in net_l:
|
||||
for m in net.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale # for residual block
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias.data, 0.0)
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
# initialization
|
||||
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
'''Residual in Residual Dense Block'''
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
self.sf = sf
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
if self.sf==4:
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
|
||||
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
if self.sf==4:
|
||||
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
return out
|
||||
@ -0,0 +1,98 @@
|
||||
import html
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
import time
|
||||
|
||||
from modules import shared
|
||||
|
||||
queue_lock = threading.Lock()
|
||||
|
||||
|
||||
def wrap_queued_call(func):
|
||||
def f(*args, **kwargs):
|
||||
with queue_lock:
|
||||
res = func(*args, **kwargs)
|
||||
|
||||
return res
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||
def f(*args, **kwargs):
|
||||
|
||||
shared.state.begin()
|
||||
|
||||
with queue_lock:
|
||||
res = func(*args, **kwargs)
|
||||
|
||||
shared.state.end()
|
||||
|
||||
return res
|
||||
|
||||
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
|
||||
|
||||
|
||||
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
|
||||
if run_memmon:
|
||||
shared.mem_mon.monitor()
|
||||
t = time.perf_counter()
|
||||
|
||||
try:
|
||||
res = list(func(*args, **kwargs))
|
||||
except Exception as e:
|
||||
# When printing out our debug argument list, do not print out more than a MB of text
|
||||
max_debug_str_len = 131072 # (1024*1024)/8
|
||||
|
||||
print("Error completing request", file=sys.stderr)
|
||||
argStr = f"Arguments: {str(args)} {str(kwargs)}"
|
||||
print(argStr[:max_debug_str_len], file=sys.stderr)
|
||||
if len(argStr) > max_debug_str_len:
|
||||
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
|
||||
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
shared.state.job = ""
|
||||
shared.state.job_count = 0
|
||||
|
||||
if extra_outputs_array is None:
|
||||
extra_outputs_array = [None, '']
|
||||
|
||||
res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
|
||||
|
||||
shared.state.skipped = False
|
||||
shared.state.interrupted = False
|
||||
shared.state.job_count = 0
|
||||
|
||||
if not add_stats:
|
||||
return tuple(res)
|
||||
|
||||
elapsed = time.perf_counter() - t
|
||||
elapsed_m = int(elapsed // 60)
|
||||
elapsed_s = elapsed % 60
|
||||
elapsed_text = f"{elapsed_s:.2f}s"
|
||||
if elapsed_m > 0:
|
||||
elapsed_text = f"{elapsed_m}m "+elapsed_text
|
||||
|
||||
if run_memmon:
|
||||
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
|
||||
active_peak = mem_stats['active_peak']
|
||||
reserved_peak = mem_stats['reserved_peak']
|
||||
sys_peak = mem_stats['system_peak']
|
||||
sys_total = mem_stats['total']
|
||||
sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
|
||||
|
||||
vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
|
||||
else:
|
||||
vram_html = ''
|
||||
|
||||
# last item is always HTML
|
||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
|
||||
|
||||
return tuple(res)
|
||||
|
||||
return f
|
||||
|
||||
@ -1,172 +1,99 @@
|
||||
import os.path
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
import multiprocessing
|
||||
import time
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from modules import modelloader, paths, deepbooru_model, devices, images, shared
|
||||
|
||||
re_special = re.compile(r'([\\()])')
|
||||
|
||||
def get_deepbooru_tags(pil_image):
|
||||
"""
|
||||
This method is for running only one image at a time for simple use. Used to the img2img interrogate.
|
||||
"""
|
||||
from modules import shared # prevents circular reference
|
||||
|
||||
try:
|
||||
create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
|
||||
return get_tags_from_process(pil_image)
|
||||
finally:
|
||||
release_process()
|
||||
|
||||
|
||||
OPT_INCLUDE_RANKS = "include_ranks"
|
||||
def create_deepbooru_opts():
|
||||
from modules import shared
|
||||
|
||||
return {
|
||||
"use_spaces": shared.opts.deepbooru_use_spaces,
|
||||
"use_escape": shared.opts.deepbooru_escape,
|
||||
"alpha_sort": shared.opts.deepbooru_sort_alpha,
|
||||
OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
|
||||
}
|
||||
|
||||
|
||||
def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
|
||||
model, tags = get_deepbooru_tags_model()
|
||||
while True: # while process is running, keep monitoring queue for new image
|
||||
pil_image = queue.get()
|
||||
if pil_image == "QUIT":
|
||||
break
|
||||
else:
|
||||
deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
|
||||
|
||||
|
||||
def create_deepbooru_process(threshold, deepbooru_opts):
|
||||
"""
|
||||
Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
|
||||
to be processed in a row without reloading the model or creating a new process. To return the data, a shared
|
||||
dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned
|
||||
to the dictionary and the method adding the image to the queue should wait for this value to be updated with
|
||||
the tags.
|
||||
"""
|
||||
from modules import shared # prevents circular reference
|
||||
shared.deepbooru_process_manager = multiprocessing.Manager()
|
||||
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
|
||||
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
|
||||
shared.deepbooru_process.start()
|
||||
|
||||
|
||||
def get_tags_from_process(image):
|
||||
from modules import shared
|
||||
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
shared.deepbooru_process_queue.put(image)
|
||||
while shared.deepbooru_process_return["value"] == -1:
|
||||
time.sleep(0.2)
|
||||
caption = shared.deepbooru_process_return["value"]
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
|
||||
return caption
|
||||
|
||||
|
||||
def release_process():
|
||||
"""
|
||||
Stops the deepbooru process to return used memory
|
||||
"""
|
||||
from modules import shared # prevents circular reference
|
||||
shared.deepbooru_process_queue.put("QUIT")
|
||||
shared.deepbooru_process.join()
|
||||
shared.deepbooru_process_queue = None
|
||||
shared.deepbooru_process = None
|
||||
shared.deepbooru_process_return = None
|
||||
shared.deepbooru_process_manager = None
|
||||
|
||||
def get_deepbooru_tags_model():
|
||||
import deepdanbooru as dd
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
this_folder = os.path.dirname(__file__)
|
||||
model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
|
||||
if not os.path.exists(os.path.join(model_path, 'project.json')):
|
||||
# there is no point importing these every time
|
||||
import zipfile
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
load_file_from_url(
|
||||
r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
|
||||
model_path)
|
||||
with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
|
||||
zip_ref.extractall(model_path)
|
||||
os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
|
||||
|
||||
tags = dd.project.load_tags_from_project(model_path)
|
||||
model = dd.project.load_model_from_project(
|
||||
model_path, compile_model=False
|
||||
)
|
||||
return model, tags
|
||||
|
||||
|
||||
def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
|
||||
import deepdanbooru as dd
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
alpha_sort = deepbooru_opts['alpha_sort']
|
||||
use_spaces = deepbooru_opts['use_spaces']
|
||||
use_escape = deepbooru_opts['use_escape']
|
||||
include_ranks = deepbooru_opts['include_ranks']
|
||||
|
||||
width = model.input_shape[2]
|
||||
height = model.input_shape[1]
|
||||
image = np.array(pil_image)
|
||||
image = tf.image.resize(
|
||||
image,
|
||||
size=(height, width),
|
||||
method=tf.image.ResizeMethod.AREA,
|
||||
preserve_aspect_ratio=True,
|
||||
|
||||
class DeepDanbooru:
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
|
||||
def load(self):
|
||||
if self.model is not None:
|
||||
return
|
||||
|
||||
files = modelloader.load_models(
|
||||
model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
|
||||
model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
|
||||
ext_filter=[".pt"],
|
||||
download_name='model-resnet_custom_v3.pt',
|
||||
)
|
||||
image = image.numpy() # EagerTensor to np.array
|
||||
image = dd.image.transform_and_pad_image(image, width, height)
|
||||
image = image / 255.0
|
||||
image_shape = image.shape
|
||||
image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
|
||||
|
||||
y = model.predict(image)[0]
|
||||
self.model = deepbooru_model.DeepDanbooruModel()
|
||||
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
|
||||
|
||||
self.model.eval()
|
||||
self.model.to(devices.cpu, devices.dtype)
|
||||
|
||||
def start(self):
|
||||
self.load()
|
||||
self.model.to(devices.device)
|
||||
|
||||
def stop(self):
|
||||
if not shared.opts.interrogate_keep_models_in_memory:
|
||||
self.model.to(devices.cpu)
|
||||
devices.torch_gc()
|
||||
|
||||
def tag(self, pil_image):
|
||||
self.start()
|
||||
res = self.tag_multi(pil_image)
|
||||
self.stop()
|
||||
|
||||
result_dict = {}
|
||||
return res
|
||||
|
||||
for i, tag in enumerate(tags):
|
||||
result_dict[tag] = y[i]
|
||||
def tag_multi(self, pil_image, force_disable_ranks=False):
|
||||
threshold = shared.opts.interrogate_deepbooru_score_threshold
|
||||
use_spaces = shared.opts.deepbooru_use_spaces
|
||||
use_escape = shared.opts.deepbooru_escape
|
||||
alpha_sort = shared.opts.deepbooru_sort_alpha
|
||||
include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
|
||||
|
||||
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
|
||||
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
|
||||
|
||||
with torch.no_grad(), devices.autocast():
|
||||
x = torch.from_numpy(a).to(devices.device)
|
||||
y = self.model(x)[0].detach().cpu().numpy()
|
||||
|
||||
probability_dict = {}
|
||||
|
||||
for tag, probability in zip(self.model.tags, y):
|
||||
if probability < threshold:
|
||||
continue
|
||||
|
||||
unsorted_tags_in_theshold = []
|
||||
result_tags_print = []
|
||||
for tag in tags:
|
||||
if result_dict[tag] >= threshold:
|
||||
if tag.startswith("rating:"):
|
||||
continue
|
||||
unsorted_tags_in_theshold.append((result_dict[tag], tag))
|
||||
result_tags_print.append(f'{result_dict[tag]} {tag}')
|
||||
|
||||
# sort tags
|
||||
result_tags_out = []
|
||||
sort_ndx = 0
|
||||
probability_dict[tag] = probability
|
||||
|
||||
if alpha_sort:
|
||||
sort_ndx = 1
|
||||
tags = sorted(probability_dict)
|
||||
else:
|
||||
tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
|
||||
|
||||
res = []
|
||||
|
||||
# sort by reverse by likelihood and normal for alpha, and format tag text as requested
|
||||
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
|
||||
for weight, tag in unsorted_tags_in_theshold:
|
||||
filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
|
||||
|
||||
for tag in [x for x in tags if x not in filtertags]:
|
||||
probability = probability_dict[tag]
|
||||
tag_outformat = tag
|
||||
if use_spaces:
|
||||
tag_outformat = tag_outformat.replace('_', ' ')
|
||||
if use_escape:
|
||||
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
|
||||
if include_ranks:
|
||||
tag_outformat = f"({tag_outformat}:{weight:.3f})"
|
||||
tag_outformat = f"({tag_outformat}:{probability:.3f})"
|
||||
|
||||
res.append(tag_outformat)
|
||||
|
||||
result_tags_out.append(tag_outformat)
|
||||
return ", ".join(res)
|
||||
|
||||
print('\n'.join(sorted(result_tags_print, reverse=True)))
|
||||
|
||||
return ', '.join(result_tags_out)
|
||||
model = DeepDanbooru()
|
||||
|
||||
@ -0,0 +1,676 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
|
||||
|
||||
|
||||
class DeepDanbooruModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(DeepDanbooruModel, self).__init__()
|
||||
|
||||
self.tags = []
|
||||
|
||||
self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
|
||||
self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
|
||||
self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
|
||||
self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
||||
self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
|
||||
self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
||||
self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
|
||||
self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
||||
self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
|
||||
self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
|
||||
self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
|
||||
self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
|
||||
self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
|
||||
self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
||||
self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
||||
self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
|
||||
self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
|
||||
self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
|
||||
self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
|
||||
self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||
self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
||||
self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
||||
self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||
self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
||||
self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
||||
self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||
self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
|
||||
self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
|
||||
self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
|
||||
self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||
self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
||||
self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
||||
self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||
self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
||||
self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
||||
self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||
self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
|
||||
|
||||
def forward(self, *inputs):
|
||||
t_358, = inputs
|
||||
t_359 = t_358.permute(*[0, 3, 1, 2])
|
||||
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
|
||||
t_360 = self.n_Conv_0(t_359_padded)
|
||||
t_361 = F.relu(t_360)
|
||||
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
|
||||
t_362 = self.n_MaxPool_0(t_361)
|
||||
t_363 = self.n_Conv_1(t_362)
|
||||
t_364 = self.n_Conv_2(t_362)
|
||||
t_365 = F.relu(t_364)
|
||||
t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
|
||||
t_366 = self.n_Conv_3(t_365_padded)
|
||||
t_367 = F.relu(t_366)
|
||||
t_368 = self.n_Conv_4(t_367)
|
||||
t_369 = torch.add(t_368, t_363)
|
||||
t_370 = F.relu(t_369)
|
||||
t_371 = self.n_Conv_5(t_370)
|
||||
t_372 = F.relu(t_371)
|
||||
t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
|
||||
t_373 = self.n_Conv_6(t_372_padded)
|
||||
t_374 = F.relu(t_373)
|
||||
t_375 = self.n_Conv_7(t_374)
|
||||
t_376 = torch.add(t_375, t_370)
|
||||
t_377 = F.relu(t_376)
|
||||
t_378 = self.n_Conv_8(t_377)
|
||||
t_379 = F.relu(t_378)
|
||||
t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
|
||||
t_380 = self.n_Conv_9(t_379_padded)
|
||||
t_381 = F.relu(t_380)
|
||||
t_382 = self.n_Conv_10(t_381)
|
||||
t_383 = torch.add(t_382, t_377)
|
||||
t_384 = F.relu(t_383)
|
||||
t_385 = self.n_Conv_11(t_384)
|
||||
t_386 = self.n_Conv_12(t_384)
|
||||
t_387 = F.relu(t_386)
|
||||
t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
|
||||
t_388 = self.n_Conv_13(t_387_padded)
|
||||
t_389 = F.relu(t_388)
|
||||
t_390 = self.n_Conv_14(t_389)
|
||||
t_391 = torch.add(t_390, t_385)
|
||||
t_392 = F.relu(t_391)
|
||||
t_393 = self.n_Conv_15(t_392)
|
||||
t_394 = F.relu(t_393)
|
||||
t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
|
||||
t_395 = self.n_Conv_16(t_394_padded)
|
||||
t_396 = F.relu(t_395)
|
||||
t_397 = self.n_Conv_17(t_396)
|
||||
t_398 = torch.add(t_397, t_392)
|
||||
t_399 = F.relu(t_398)
|
||||
t_400 = self.n_Conv_18(t_399)
|
||||
t_401 = F.relu(t_400)
|
||||
t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
|
||||
t_402 = self.n_Conv_19(t_401_padded)
|
||||
t_403 = F.relu(t_402)
|
||||
t_404 = self.n_Conv_20(t_403)
|
||||
t_405 = torch.add(t_404, t_399)
|
||||
t_406 = F.relu(t_405)
|
||||
t_407 = self.n_Conv_21(t_406)
|
||||
t_408 = F.relu(t_407)
|
||||
t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
|
||||
t_409 = self.n_Conv_22(t_408_padded)
|
||||
t_410 = F.relu(t_409)
|
||||
t_411 = self.n_Conv_23(t_410)
|
||||
t_412 = torch.add(t_411, t_406)
|
||||
t_413 = F.relu(t_412)
|
||||
t_414 = self.n_Conv_24(t_413)
|
||||
t_415 = F.relu(t_414)
|
||||
t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
|
||||
t_416 = self.n_Conv_25(t_415_padded)
|
||||
t_417 = F.relu(t_416)
|
||||
t_418 = self.n_Conv_26(t_417)
|
||||
t_419 = torch.add(t_418, t_413)
|
||||
t_420 = F.relu(t_419)
|
||||
t_421 = self.n_Conv_27(t_420)
|
||||
t_422 = F.relu(t_421)
|
||||
t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
|
||||
t_423 = self.n_Conv_28(t_422_padded)
|
||||
t_424 = F.relu(t_423)
|
||||
t_425 = self.n_Conv_29(t_424)
|
||||
t_426 = torch.add(t_425, t_420)
|
||||
t_427 = F.relu(t_426)
|
||||
t_428 = self.n_Conv_30(t_427)
|
||||
t_429 = F.relu(t_428)
|
||||
t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
|
||||
t_430 = self.n_Conv_31(t_429_padded)
|
||||
t_431 = F.relu(t_430)
|
||||
t_432 = self.n_Conv_32(t_431)
|
||||
t_433 = torch.add(t_432, t_427)
|
||||
t_434 = F.relu(t_433)
|
||||
t_435 = self.n_Conv_33(t_434)
|
||||
t_436 = F.relu(t_435)
|
||||
t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
|
||||
t_437 = self.n_Conv_34(t_436_padded)
|
||||
t_438 = F.relu(t_437)
|
||||
t_439 = self.n_Conv_35(t_438)
|
||||
t_440 = torch.add(t_439, t_434)
|
||||
t_441 = F.relu(t_440)
|
||||
t_442 = self.n_Conv_36(t_441)
|
||||
t_443 = self.n_Conv_37(t_441)
|
||||
t_444 = F.relu(t_443)
|
||||
t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
|
||||
t_445 = self.n_Conv_38(t_444_padded)
|
||||
t_446 = F.relu(t_445)
|
||||
t_447 = self.n_Conv_39(t_446)
|
||||
t_448 = torch.add(t_447, t_442)
|
||||
t_449 = F.relu(t_448)
|
||||
t_450 = self.n_Conv_40(t_449)
|
||||
t_451 = F.relu(t_450)
|
||||
t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
|
||||
t_452 = self.n_Conv_41(t_451_padded)
|
||||
t_453 = F.relu(t_452)
|
||||
t_454 = self.n_Conv_42(t_453)
|
||||
t_455 = torch.add(t_454, t_449)
|
||||
t_456 = F.relu(t_455)
|
||||
t_457 = self.n_Conv_43(t_456)
|
||||
t_458 = F.relu(t_457)
|
||||
t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
|
||||
t_459 = self.n_Conv_44(t_458_padded)
|
||||
t_460 = F.relu(t_459)
|
||||
t_461 = self.n_Conv_45(t_460)
|
||||
t_462 = torch.add(t_461, t_456)
|
||||
t_463 = F.relu(t_462)
|
||||
t_464 = self.n_Conv_46(t_463)
|
||||
t_465 = F.relu(t_464)
|
||||
t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
|
||||
t_466 = self.n_Conv_47(t_465_padded)
|
||||
t_467 = F.relu(t_466)
|
||||
t_468 = self.n_Conv_48(t_467)
|
||||
t_469 = torch.add(t_468, t_463)
|
||||
t_470 = F.relu(t_469)
|
||||
t_471 = self.n_Conv_49(t_470)
|
||||
t_472 = F.relu(t_471)
|
||||
t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
|
||||
t_473 = self.n_Conv_50(t_472_padded)
|
||||
t_474 = F.relu(t_473)
|
||||
t_475 = self.n_Conv_51(t_474)
|
||||
t_476 = torch.add(t_475, t_470)
|
||||
t_477 = F.relu(t_476)
|
||||
t_478 = self.n_Conv_52(t_477)
|
||||
t_479 = F.relu(t_478)
|
||||
t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
|
||||
t_480 = self.n_Conv_53(t_479_padded)
|
||||
t_481 = F.relu(t_480)
|
||||
t_482 = self.n_Conv_54(t_481)
|
||||
t_483 = torch.add(t_482, t_477)
|
||||
t_484 = F.relu(t_483)
|
||||
t_485 = self.n_Conv_55(t_484)
|
||||
t_486 = F.relu(t_485)
|
||||
t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
|
||||
t_487 = self.n_Conv_56(t_486_padded)
|
||||
t_488 = F.relu(t_487)
|
||||
t_489 = self.n_Conv_57(t_488)
|
||||
t_490 = torch.add(t_489, t_484)
|
||||
t_491 = F.relu(t_490)
|
||||
t_492 = self.n_Conv_58(t_491)
|
||||
t_493 = F.relu(t_492)
|
||||
t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
|
||||
t_494 = self.n_Conv_59(t_493_padded)
|
||||
t_495 = F.relu(t_494)
|
||||
t_496 = self.n_Conv_60(t_495)
|
||||
t_497 = torch.add(t_496, t_491)
|
||||
t_498 = F.relu(t_497)
|
||||
t_499 = self.n_Conv_61(t_498)
|
||||
t_500 = F.relu(t_499)
|
||||
t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
|
||||
t_501 = self.n_Conv_62(t_500_padded)
|
||||
t_502 = F.relu(t_501)
|
||||
t_503 = self.n_Conv_63(t_502)
|
||||
t_504 = torch.add(t_503, t_498)
|
||||
t_505 = F.relu(t_504)
|
||||
t_506 = self.n_Conv_64(t_505)
|
||||
t_507 = F.relu(t_506)
|
||||
t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
|
||||
t_508 = self.n_Conv_65(t_507_padded)
|
||||
t_509 = F.relu(t_508)
|
||||
t_510 = self.n_Conv_66(t_509)
|
||||
t_511 = torch.add(t_510, t_505)
|
||||
t_512 = F.relu(t_511)
|
||||
t_513 = self.n_Conv_67(t_512)
|
||||
t_514 = F.relu(t_513)
|
||||
t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
|
||||
t_515 = self.n_Conv_68(t_514_padded)
|
||||
t_516 = F.relu(t_515)
|
||||
t_517 = self.n_Conv_69(t_516)
|
||||
t_518 = torch.add(t_517, t_512)
|
||||
t_519 = F.relu(t_518)
|
||||
t_520 = self.n_Conv_70(t_519)
|
||||
t_521 = F.relu(t_520)
|
||||
t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
|
||||
t_522 = self.n_Conv_71(t_521_padded)
|
||||
t_523 = F.relu(t_522)
|
||||
t_524 = self.n_Conv_72(t_523)
|
||||
t_525 = torch.add(t_524, t_519)
|
||||
t_526 = F.relu(t_525)
|
||||
t_527 = self.n_Conv_73(t_526)
|
||||
t_528 = F.relu(t_527)
|
||||
t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
|
||||
t_529 = self.n_Conv_74(t_528_padded)
|
||||
t_530 = F.relu(t_529)
|
||||
t_531 = self.n_Conv_75(t_530)
|
||||
t_532 = torch.add(t_531, t_526)
|
||||
t_533 = F.relu(t_532)
|
||||
t_534 = self.n_Conv_76(t_533)
|
||||
t_535 = F.relu(t_534)
|
||||
t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
|
||||
t_536 = self.n_Conv_77(t_535_padded)
|
||||
t_537 = F.relu(t_536)
|
||||
t_538 = self.n_Conv_78(t_537)
|
||||
t_539 = torch.add(t_538, t_533)
|
||||
t_540 = F.relu(t_539)
|
||||
t_541 = self.n_Conv_79(t_540)
|
||||
t_542 = F.relu(t_541)
|
||||
t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
|
||||
t_543 = self.n_Conv_80(t_542_padded)
|
||||
t_544 = F.relu(t_543)
|
||||
t_545 = self.n_Conv_81(t_544)
|
||||
t_546 = torch.add(t_545, t_540)
|
||||
t_547 = F.relu(t_546)
|
||||
t_548 = self.n_Conv_82(t_547)
|
||||
t_549 = F.relu(t_548)
|
||||
t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
|
||||
t_550 = self.n_Conv_83(t_549_padded)
|
||||
t_551 = F.relu(t_550)
|
||||
t_552 = self.n_Conv_84(t_551)
|
||||
t_553 = torch.add(t_552, t_547)
|
||||
t_554 = F.relu(t_553)
|
||||
t_555 = self.n_Conv_85(t_554)
|
||||
t_556 = F.relu(t_555)
|
||||
t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
|
||||
t_557 = self.n_Conv_86(t_556_padded)
|
||||
t_558 = F.relu(t_557)
|
||||
t_559 = self.n_Conv_87(t_558)
|
||||
t_560 = torch.add(t_559, t_554)
|
||||
t_561 = F.relu(t_560)
|
||||
t_562 = self.n_Conv_88(t_561)
|
||||
t_563 = F.relu(t_562)
|
||||
t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
|
||||
t_564 = self.n_Conv_89(t_563_padded)
|
||||
t_565 = F.relu(t_564)
|
||||
t_566 = self.n_Conv_90(t_565)
|
||||
t_567 = torch.add(t_566, t_561)
|
||||
t_568 = F.relu(t_567)
|
||||
t_569 = self.n_Conv_91(t_568)
|
||||
t_570 = F.relu(t_569)
|
||||
t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
|
||||
t_571 = self.n_Conv_92(t_570_padded)
|
||||
t_572 = F.relu(t_571)
|
||||
t_573 = self.n_Conv_93(t_572)
|
||||
t_574 = torch.add(t_573, t_568)
|
||||
t_575 = F.relu(t_574)
|
||||
t_576 = self.n_Conv_94(t_575)
|
||||
t_577 = F.relu(t_576)
|
||||
t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
|
||||
t_578 = self.n_Conv_95(t_577_padded)
|
||||
t_579 = F.relu(t_578)
|
||||
t_580 = self.n_Conv_96(t_579)
|
||||
t_581 = torch.add(t_580, t_575)
|
||||
t_582 = F.relu(t_581)
|
||||
t_583 = self.n_Conv_97(t_582)
|
||||
t_584 = F.relu(t_583)
|
||||
t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
|
||||
t_585 = self.n_Conv_98(t_584_padded)
|
||||
t_586 = F.relu(t_585)
|
||||
t_587 = self.n_Conv_99(t_586)
|
||||
t_588 = self.n_Conv_100(t_582)
|
||||
t_589 = torch.add(t_587, t_588)
|
||||
t_590 = F.relu(t_589)
|
||||
t_591 = self.n_Conv_101(t_590)
|
||||
t_592 = F.relu(t_591)
|
||||
t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
|
||||
t_593 = self.n_Conv_102(t_592_padded)
|
||||
t_594 = F.relu(t_593)
|
||||
t_595 = self.n_Conv_103(t_594)
|
||||
t_596 = torch.add(t_595, t_590)
|
||||
t_597 = F.relu(t_596)
|
||||
t_598 = self.n_Conv_104(t_597)
|
||||
t_599 = F.relu(t_598)
|
||||
t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
|
||||
t_600 = self.n_Conv_105(t_599_padded)
|
||||
t_601 = F.relu(t_600)
|
||||
t_602 = self.n_Conv_106(t_601)
|
||||
t_603 = torch.add(t_602, t_597)
|
||||
t_604 = F.relu(t_603)
|
||||
t_605 = self.n_Conv_107(t_604)
|
||||
t_606 = F.relu(t_605)
|
||||
t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
|
||||
t_607 = self.n_Conv_108(t_606_padded)
|
||||
t_608 = F.relu(t_607)
|
||||
t_609 = self.n_Conv_109(t_608)
|
||||
t_610 = torch.add(t_609, t_604)
|
||||
t_611 = F.relu(t_610)
|
||||
t_612 = self.n_Conv_110(t_611)
|
||||
t_613 = F.relu(t_612)
|
||||
t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
|
||||
t_614 = self.n_Conv_111(t_613_padded)
|
||||
t_615 = F.relu(t_614)
|
||||
t_616 = self.n_Conv_112(t_615)
|
||||
t_617 = torch.add(t_616, t_611)
|
||||
t_618 = F.relu(t_617)
|
||||
t_619 = self.n_Conv_113(t_618)
|
||||
t_620 = F.relu(t_619)
|
||||
t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
|
||||
t_621 = self.n_Conv_114(t_620_padded)
|
||||
t_622 = F.relu(t_621)
|
||||
t_623 = self.n_Conv_115(t_622)
|
||||
t_624 = torch.add(t_623, t_618)
|
||||
t_625 = F.relu(t_624)
|
||||
t_626 = self.n_Conv_116(t_625)
|
||||
t_627 = F.relu(t_626)
|
||||
t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
|
||||
t_628 = self.n_Conv_117(t_627_padded)
|
||||
t_629 = F.relu(t_628)
|
||||
t_630 = self.n_Conv_118(t_629)
|
||||
t_631 = torch.add(t_630, t_625)
|
||||
t_632 = F.relu(t_631)
|
||||
t_633 = self.n_Conv_119(t_632)
|
||||
t_634 = F.relu(t_633)
|
||||
t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
|
||||
t_635 = self.n_Conv_120(t_634_padded)
|
||||
t_636 = F.relu(t_635)
|
||||
t_637 = self.n_Conv_121(t_636)
|
||||
t_638 = torch.add(t_637, t_632)
|
||||
t_639 = F.relu(t_638)
|
||||
t_640 = self.n_Conv_122(t_639)
|
||||
t_641 = F.relu(t_640)
|
||||
t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
|
||||
t_642 = self.n_Conv_123(t_641_padded)
|
||||
t_643 = F.relu(t_642)
|
||||
t_644 = self.n_Conv_124(t_643)
|
||||
t_645 = torch.add(t_644, t_639)
|
||||
t_646 = F.relu(t_645)
|
||||
t_647 = self.n_Conv_125(t_646)
|
||||
t_648 = F.relu(t_647)
|
||||
t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
|
||||
t_649 = self.n_Conv_126(t_648_padded)
|
||||
t_650 = F.relu(t_649)
|
||||
t_651 = self.n_Conv_127(t_650)
|
||||
t_652 = torch.add(t_651, t_646)
|
||||
t_653 = F.relu(t_652)
|
||||
t_654 = self.n_Conv_128(t_653)
|
||||
t_655 = F.relu(t_654)
|
||||
t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
|
||||
t_656 = self.n_Conv_129(t_655_padded)
|
||||
t_657 = F.relu(t_656)
|
||||
t_658 = self.n_Conv_130(t_657)
|
||||
t_659 = torch.add(t_658, t_653)
|
||||
t_660 = F.relu(t_659)
|
||||
t_661 = self.n_Conv_131(t_660)
|
||||
t_662 = F.relu(t_661)
|
||||
t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
|
||||
t_663 = self.n_Conv_132(t_662_padded)
|
||||
t_664 = F.relu(t_663)
|
||||
t_665 = self.n_Conv_133(t_664)
|
||||
t_666 = torch.add(t_665, t_660)
|
||||
t_667 = F.relu(t_666)
|
||||
t_668 = self.n_Conv_134(t_667)
|
||||
t_669 = F.relu(t_668)
|
||||
t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
|
||||
t_670 = self.n_Conv_135(t_669_padded)
|
||||
t_671 = F.relu(t_670)
|
||||
t_672 = self.n_Conv_136(t_671)
|
||||
t_673 = torch.add(t_672, t_667)
|
||||
t_674 = F.relu(t_673)
|
||||
t_675 = self.n_Conv_137(t_674)
|
||||
t_676 = F.relu(t_675)
|
||||
t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
|
||||
t_677 = self.n_Conv_138(t_676_padded)
|
||||
t_678 = F.relu(t_677)
|
||||
t_679 = self.n_Conv_139(t_678)
|
||||
t_680 = torch.add(t_679, t_674)
|
||||
t_681 = F.relu(t_680)
|
||||
t_682 = self.n_Conv_140(t_681)
|
||||
t_683 = F.relu(t_682)
|
||||
t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
|
||||
t_684 = self.n_Conv_141(t_683_padded)
|
||||
t_685 = F.relu(t_684)
|
||||
t_686 = self.n_Conv_142(t_685)
|
||||
t_687 = torch.add(t_686, t_681)
|
||||
t_688 = F.relu(t_687)
|
||||
t_689 = self.n_Conv_143(t_688)
|
||||
t_690 = F.relu(t_689)
|
||||
t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
|
||||
t_691 = self.n_Conv_144(t_690_padded)
|
||||
t_692 = F.relu(t_691)
|
||||
t_693 = self.n_Conv_145(t_692)
|
||||
t_694 = torch.add(t_693, t_688)
|
||||
t_695 = F.relu(t_694)
|
||||
t_696 = self.n_Conv_146(t_695)
|
||||
t_697 = F.relu(t_696)
|
||||
t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
|
||||
t_698 = self.n_Conv_147(t_697_padded)
|
||||
t_699 = F.relu(t_698)
|
||||
t_700 = self.n_Conv_148(t_699)
|
||||
t_701 = torch.add(t_700, t_695)
|
||||
t_702 = F.relu(t_701)
|
||||
t_703 = self.n_Conv_149(t_702)
|
||||
t_704 = F.relu(t_703)
|
||||
t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
|
||||
t_705 = self.n_Conv_150(t_704_padded)
|
||||
t_706 = F.relu(t_705)
|
||||
t_707 = self.n_Conv_151(t_706)
|
||||
t_708 = torch.add(t_707, t_702)
|
||||
t_709 = F.relu(t_708)
|
||||
t_710 = self.n_Conv_152(t_709)
|
||||
t_711 = F.relu(t_710)
|
||||
t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
|
||||
t_712 = self.n_Conv_153(t_711_padded)
|
||||
t_713 = F.relu(t_712)
|
||||
t_714 = self.n_Conv_154(t_713)
|
||||
t_715 = torch.add(t_714, t_709)
|
||||
t_716 = F.relu(t_715)
|
||||
t_717 = self.n_Conv_155(t_716)
|
||||
t_718 = F.relu(t_717)
|
||||
t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
|
||||
t_719 = self.n_Conv_156(t_718_padded)
|
||||
t_720 = F.relu(t_719)
|
||||
t_721 = self.n_Conv_157(t_720)
|
||||
t_722 = torch.add(t_721, t_716)
|
||||
t_723 = F.relu(t_722)
|
||||
t_724 = self.n_Conv_158(t_723)
|
||||
t_725 = self.n_Conv_159(t_723)
|
||||
t_726 = F.relu(t_725)
|
||||
t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
|
||||
t_727 = self.n_Conv_160(t_726_padded)
|
||||
t_728 = F.relu(t_727)
|
||||
t_729 = self.n_Conv_161(t_728)
|
||||
t_730 = torch.add(t_729, t_724)
|
||||
t_731 = F.relu(t_730)
|
||||
t_732 = self.n_Conv_162(t_731)
|
||||
t_733 = F.relu(t_732)
|
||||
t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
|
||||
t_734 = self.n_Conv_163(t_733_padded)
|
||||
t_735 = F.relu(t_734)
|
||||
t_736 = self.n_Conv_164(t_735)
|
||||
t_737 = torch.add(t_736, t_731)
|
||||
t_738 = F.relu(t_737)
|
||||
t_739 = self.n_Conv_165(t_738)
|
||||
t_740 = F.relu(t_739)
|
||||
t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
|
||||
t_741 = self.n_Conv_166(t_740_padded)
|
||||
t_742 = F.relu(t_741)
|
||||
t_743 = self.n_Conv_167(t_742)
|
||||
t_744 = torch.add(t_743, t_738)
|
||||
t_745 = F.relu(t_744)
|
||||
t_746 = self.n_Conv_168(t_745)
|
||||
t_747 = self.n_Conv_169(t_745)
|
||||
t_748 = F.relu(t_747)
|
||||
t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
|
||||
t_749 = self.n_Conv_170(t_748_padded)
|
||||
t_750 = F.relu(t_749)
|
||||
t_751 = self.n_Conv_171(t_750)
|
||||
t_752 = torch.add(t_751, t_746)
|
||||
t_753 = F.relu(t_752)
|
||||
t_754 = self.n_Conv_172(t_753)
|
||||
t_755 = F.relu(t_754)
|
||||
t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
|
||||
t_756 = self.n_Conv_173(t_755_padded)
|
||||
t_757 = F.relu(t_756)
|
||||
t_758 = self.n_Conv_174(t_757)
|
||||
t_759 = torch.add(t_758, t_753)
|
||||
t_760 = F.relu(t_759)
|
||||
t_761 = self.n_Conv_175(t_760)
|
||||
t_762 = F.relu(t_761)
|
||||
t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
|
||||
t_763 = self.n_Conv_176(t_762_padded)
|
||||
t_764 = F.relu(t_763)
|
||||
t_765 = self.n_Conv_177(t_764)
|
||||
t_766 = torch.add(t_765, t_760)
|
||||
t_767 = F.relu(t_766)
|
||||
t_768 = self.n_Conv_178(t_767)
|
||||
t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
|
||||
t_770 = torch.squeeze(t_769, 3)
|
||||
t_770 = torch.squeeze(t_770, 2)
|
||||
t_771 = torch.sigmoid(t_770)
|
||||
return t_771
|
||||
|
||||
def load_state_dict(self, state_dict, **kwargs):
|
||||
self.tags = state_dict.get('tags', [])
|
||||
|
||||
super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
|
||||
|
||||
@ -1,80 +1,463 @@
|
||||
# this file is taken from https://github.com/xinntao/ESRGAN
|
||||
# this file is adapted from https://github.com/victorca25/iNNfer
|
||||
|
||||
import math
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
####################
|
||||
# RRDBNet Generator
|
||||
####################
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
|
||||
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
|
||||
finalact=None, gaussian_noise=False, plus=False):
|
||||
super(RRDBNet, self).__init__()
|
||||
n_upscale = int(math.log(upscale, 2))
|
||||
if upscale == 3:
|
||||
n_upscale = 1
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
self.resrgan_scale = 0
|
||||
if in_nc % 16 == 0:
|
||||
self.resrgan_scale = 1
|
||||
elif in_nc != 4 and in_nc % 4 == 0:
|
||||
self.resrgan_scale = 2
|
||||
|
||||
# initialization
|
||||
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
||||
rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
|
||||
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
|
||||
LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
if upsample_mode == 'upconv':
|
||||
upsample_block = upconv_block
|
||||
elif upsample_mode == 'pixelshuffle':
|
||||
upsample_block = pixelshuffle_block
|
||||
else:
|
||||
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
|
||||
if upscale == 3:
|
||||
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
|
||||
else:
|
||||
upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
|
||||
HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
|
||||
HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
||||
|
||||
outact = act(finalact) if finalact else None
|
||||
|
||||
self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
|
||||
*upsampler, HR_conv0, HR_conv1, outact)
|
||||
|
||||
def forward(self, x, outm=None):
|
||||
if self.resrgan_scale == 1:
|
||||
feat = pixel_unshuffle(x, scale=4)
|
||||
elif self.resrgan_scale == 2:
|
||||
feat = pixel_unshuffle(x, scale=2)
|
||||
else:
|
||||
feat = x
|
||||
|
||||
return self.model(feat)
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
'''Residual in Residual Dense Block'''
|
||||
"""
|
||||
Residual in Residual Dense Block
|
||||
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
|
||||
"""
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False, gaussian_noise=False, plus=False):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
# This is for backwards compatibility with existing models
|
||||
if nr == 3:
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
else:
|
||||
RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
|
||||
self.RDBs = nn.Sequential(*RDB_list)
|
||||
|
||||
def forward(self, x):
|
||||
if hasattr(self, 'RDB1'):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
else:
|
||||
out = self.RDBs(x)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
"""
|
||||
Residual Dense Block
|
||||
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
|
||||
Modified options that can be used:
|
||||
- "Partial Convolution based Padding" arXiv:1811.11718
|
||||
- "Spectral normalization" arXiv:1802.05957
|
||||
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
||||
{Rakotonirina} and A. {Rasoanaivo}
|
||||
"""
|
||||
|
||||
def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False, gaussian_noise=False, plus=False):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
|
||||
self.noise = GaussianNoise() if gaussian_noise else None
|
||||
self.conv1x1 = conv1x1(nf, gc) if plus else None
|
||||
|
||||
self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
if mode == 'CNA':
|
||||
last_act = None
|
||||
else:
|
||||
last_act = act_type
|
||||
self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.conv1(x)
|
||||
x2 = self.conv2(torch.cat((x, x1), 1))
|
||||
if self.conv1x1:
|
||||
x2 = x2 + self.conv1x1(x)
|
||||
x3 = self.conv3(torch.cat((x, x1, x2), 1))
|
||||
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
|
||||
if self.conv1x1:
|
||||
x4 = x4 + x2
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
if self.noise:
|
||||
return self.noise(x5.mul(0.2) + x)
|
||||
else:
|
||||
return x5 * 0.2 + x
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
####################
|
||||
# ESRGANplus
|
||||
####################
|
||||
|
||||
class GaussianNoise(nn.Module):
|
||||
def __init__(self, sigma=0.1, is_relative_detach=False):
|
||||
super().__init__()
|
||||
self.sigma = sigma
|
||||
self.is_relative_detach = is_relative_detach
|
||||
self.noise = torch.tensor(0, dtype=torch.float)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
if self.training and self.sigma != 0:
|
||||
self.noise = self.noise.to(x.device)
|
||||
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
||||
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
||||
x = x + sampled_noise
|
||||
return x
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
####################
|
||||
# SRVGGNetCompact
|
||||
####################
|
||||
|
||||
class SRVGGNetCompact(nn.Module):
|
||||
"""A compact VGG-style network structure for super-resolution.
|
||||
This class is copied from https://github.com/xinntao/Real-ESRGAN
|
||||
"""
|
||||
|
||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
|
||||
super(SRVGGNetCompact, self).__init__()
|
||||
self.num_in_ch = num_in_ch
|
||||
self.num_out_ch = num_out_ch
|
||||
self.num_feat = num_feat
|
||||
self.num_conv = num_conv
|
||||
self.upscale = upscale
|
||||
self.act_type = act_type
|
||||
|
||||
self.body = nn.ModuleList()
|
||||
# the first conv
|
||||
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
|
||||
# the first activation
|
||||
if act_type == 'relu':
|
||||
activation = nn.ReLU(inplace=True)
|
||||
elif act_type == 'prelu':
|
||||
activation = nn.PReLU(num_parameters=num_feat)
|
||||
elif act_type == 'leakyrelu':
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.body.append(activation)
|
||||
|
||||
# the body structure
|
||||
for _ in range(num_conv):
|
||||
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
|
||||
# activation
|
||||
if act_type == 'relu':
|
||||
activation = nn.ReLU(inplace=True)
|
||||
elif act_type == 'prelu':
|
||||
activation = nn.PReLU(num_parameters=num_feat)
|
||||
elif act_type == 'leakyrelu':
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.body.append(activation)
|
||||
|
||||
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
# the last conv
|
||||
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
|
||||
# upsample
|
||||
self.upsampler = nn.PixelShuffle(upscale)
|
||||
|
||||
def forward(self, x):
|
||||
out = x
|
||||
for i in range(0, len(self.body)):
|
||||
out = self.body[i](out)
|
||||
|
||||
out = self.upsampler(out)
|
||||
# add the nearest upsampled image, so that the network learns the residual
|
||||
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
|
||||
out += base
|
||||
return out
|
||||
|
||||
|
||||
####################
|
||||
# Upsampler
|
||||
####################
|
||||
|
||||
class Upsample(nn.Module):
|
||||
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
|
||||
The input data is assumed to be of the form
|
||||
`minibatch x channels x [optional depth] x [optional height] x width`.
|
||||
"""
|
||||
|
||||
def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
super(Upsample, self).__init__()
|
||||
if isinstance(scale_factor, tuple):
|
||||
self.scale_factor = tuple(float(factor) for factor in scale_factor)
|
||||
else:
|
||||
self.scale_factor = float(scale_factor) if scale_factor else None
|
||||
self.mode = mode
|
||||
self.size = size
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
|
||||
|
||||
def extra_repr(self):
|
||||
if self.scale_factor is not None:
|
||||
info = 'scale_factor=' + str(self.scale_factor)
|
||||
else:
|
||||
info = 'size=' + str(self.size)
|
||||
info += ', mode=' + self.mode
|
||||
return info
|
||||
|
||||
|
||||
def pixel_unshuffle(x, scale):
|
||||
""" Pixel unshuffle.
|
||||
Args:
|
||||
x (Tensor): Input feature with shape (b, c, hh, hw).
|
||||
scale (int): Downsample ratio.
|
||||
Returns:
|
||||
Tensor: the pixel unshuffled feature.
|
||||
"""
|
||||
b, c, hh, hw = x.size()
|
||||
out_channel = c * (scale**2)
|
||||
assert hh % scale == 0 and hw % scale == 0
|
||||
h = hh // scale
|
||||
w = hw // scale
|
||||
x_view = x.view(b, c, h, scale, w, scale)
|
||||
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
|
||||
|
||||
|
||||
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
|
||||
"""
|
||||
Pixel shuffle layer
|
||||
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
|
||||
Neural Network, CVPR17)
|
||||
"""
|
||||
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
|
||||
pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
|
||||
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
||||
|
||||
n = norm(norm_type, out_nc) if norm_type else None
|
||||
a = act(act_type) if act_type else None
|
||||
return sequential(conv, pixel_shuffle, n, a)
|
||||
|
||||
|
||||
def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
|
||||
""" Upconv layer """
|
||||
upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
|
||||
upsample = Upsample(scale_factor=upscale_factor, mode=mode)
|
||||
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
|
||||
pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
|
||||
return sequential(upsample, conv)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
####################
|
||||
# Basic blocks
|
||||
####################
|
||||
|
||||
|
||||
def make_layer(basic_block, num_basic_block, **kwarg):
|
||||
"""Make layers by stacking the same blocks.
|
||||
Args:
|
||||
basic_block (nn.module): nn.module class for basic block. (block)
|
||||
num_basic_block (int): number of blocks. (n_layers)
|
||||
Returns:
|
||||
nn.Sequential: Stacked blocks in nn.Sequential.
|
||||
"""
|
||||
layers = []
|
||||
for _ in range(num_basic_block):
|
||||
layers.append(basic_block(**kwarg))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
|
||||
""" activation helper """
|
||||
act_type = act_type.lower()
|
||||
if act_type == 'relu':
|
||||
layer = nn.ReLU(inplace)
|
||||
elif act_type in ('leakyrelu', 'lrelu'):
|
||||
layer = nn.LeakyReLU(neg_slope, inplace)
|
||||
elif act_type == 'prelu':
|
||||
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
|
||||
elif act_type == 'tanh': # [-1, 1] range output
|
||||
layer = nn.Tanh()
|
||||
elif act_type == 'sigmoid': # [0, 1] range output
|
||||
layer = nn.Sigmoid()
|
||||
else:
|
||||
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
|
||||
return layer
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self, *kwargs):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x, *kwargs):
|
||||
return x
|
||||
|
||||
|
||||
def norm(norm_type, nc):
|
||||
""" Return a normalization layer """
|
||||
norm_type = norm_type.lower()
|
||||
if norm_type == 'batch':
|
||||
layer = nn.BatchNorm2d(nc, affine=True)
|
||||
elif norm_type == 'instance':
|
||||
layer = nn.InstanceNorm2d(nc, affine=False)
|
||||
elif norm_type == 'none':
|
||||
def norm_layer(x): return Identity()
|
||||
else:
|
||||
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
|
||||
return layer
|
||||
|
||||
|
||||
def pad(pad_type, padding):
|
||||
""" padding layer helper """
|
||||
pad_type = pad_type.lower()
|
||||
if padding == 0:
|
||||
return None
|
||||
if pad_type == 'reflect':
|
||||
layer = nn.ReflectionPad2d(padding)
|
||||
elif pad_type == 'replicate':
|
||||
layer = nn.ReplicationPad2d(padding)
|
||||
elif pad_type == 'zero':
|
||||
layer = nn.ZeroPad2d(padding)
|
||||
else:
|
||||
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
|
||||
return layer
|
||||
|
||||
|
||||
def get_valid_padding(kernel_size, dilation):
|
||||
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
|
||||
padding = (kernel_size - 1) // 2
|
||||
return padding
|
||||
|
||||
|
||||
class ShortcutBlock(nn.Module):
|
||||
""" Elementwise sum the output of a submodule to its input """
|
||||
def __init__(self, submodule):
|
||||
super(ShortcutBlock, self).__init__()
|
||||
self.sub = submodule
|
||||
|
||||
def forward(self, x):
|
||||
output = x + self.sub(x)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
|
||||
|
||||
|
||||
def sequential(*args):
|
||||
""" Flatten Sequential. It unwraps nn.Sequential. """
|
||||
if len(args) == 1:
|
||||
if isinstance(args[0], OrderedDict):
|
||||
raise NotImplementedError('sequential does not support OrderedDict input.')
|
||||
return args[0] # No sequential is needed.
|
||||
modules = []
|
||||
for module in args:
|
||||
if isinstance(module, nn.Sequential):
|
||||
for submodule in module.children():
|
||||
modules.append(submodule)
|
||||
elif isinstance(module, nn.Module):
|
||||
modules.append(module)
|
||||
return nn.Sequential(*modules)
|
||||
|
||||
|
||||
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False):
|
||||
""" Conv layer with padding, normalization, activation """
|
||||
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
|
||||
padding = get_valid_padding(kernel_size, dilation)
|
||||
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
||||
padding = padding if pad_type == 'zero' else 0
|
||||
|
||||
if convtype=='PartialConv2D':
|
||||
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
elif convtype=='DeformConv2D':
|
||||
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
elif convtype=='Conv3D':
|
||||
c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
else:
|
||||
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
|
||||
if spectral_norm:
|
||||
c = nn.utils.spectral_norm(c)
|
||||
|
||||
a = act(act_type) if act_type else None
|
||||
if 'CNA' in mode:
|
||||
n = norm(norm_type, out_nc) if norm_type else None
|
||||
return sequential(p, c, n, a)
|
||||
elif mode == 'NAC':
|
||||
if norm_type is None and act_type is not None:
|
||||
a = act(act_type, inplace=False)
|
||||
n = norm(norm_type, in_nc) if norm_type else None
|
||||
return sequential(n, a, p, c)
|
||||
|
||||
@ -0,0 +1,99 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import git
|
||||
|
||||
from modules import paths, shared
|
||||
|
||||
extensions = []
|
||||
extensions_dir = os.path.join(paths.script_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
|
||||
|
||||
|
||||
def active():
|
||||
return [x for x in extensions if x.enabled]
|
||||
|
||||
|
||||
class Extension:
|
||||
def __init__(self, name, path, enabled=True, is_builtin=False):
|
||||
self.name = name
|
||||
self.path = path
|
||||
self.enabled = enabled
|
||||
self.status = ''
|
||||
self.can_update = False
|
||||
self.is_builtin = is_builtin
|
||||
|
||||
repo = None
|
||||
try:
|
||||
if os.path.exists(os.path.join(path, ".git")):
|
||||
repo = git.Repo(path)
|
||||
except Exception:
|
||||
print(f"Error reading github repository info from {path}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
if repo is None or repo.bare:
|
||||
self.remote = None
|
||||
else:
|
||||
try:
|
||||
self.remote = next(repo.remote().urls, None)
|
||||
self.status = 'unknown'
|
||||
except Exception:
|
||||
self.remote = None
|
||||
|
||||
def list_files(self, subdir, extension):
|
||||
from modules import scripts
|
||||
|
||||
dirpath = os.path.join(self.path, subdir)
|
||||
if not os.path.isdir(dirpath):
|
||||
return []
|
||||
|
||||
res = []
|
||||
for filename in sorted(os.listdir(dirpath)):
|
||||
res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
|
||||
|
||||
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||
|
||||
return res
|
||||
|
||||
def check_updates(self):
|
||||
repo = git.Repo(self.path)
|
||||
for fetch in repo.remote().fetch("--dry-run"):
|
||||
if fetch.flags != fetch.HEAD_UPTODATE:
|
||||
self.can_update = True
|
||||
self.status = "behind"
|
||||
return
|
||||
|
||||
self.can_update = False
|
||||
self.status = "latest"
|
||||
|
||||
def fetch_and_reset_hard(self):
|
||||
repo = git.Repo(self.path)
|
||||
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
||||
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
||||
repo.git.fetch('--all')
|
||||
repo.git.reset('--hard', 'origin')
|
||||
|
||||
|
||||
def list_extensions():
|
||||
extensions.clear()
|
||||
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
paths = []
|
||||
for dirname in [extensions_dir, extensions_builtin_dir]:
|
||||
if not os.path.isdir(dirname):
|
||||
return
|
||||
|
||||
for extension_dirname in sorted(os.listdir(dirname)):
|
||||
path = os.path.join(dirname, extension_dirname)
|
||||
if not os.path.isdir(path):
|
||||
continue
|
||||
|
||||
paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
|
||||
|
||||
for dirname, path, is_builtin in paths:
|
||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
|
||||
extensions.append(extension)
|
||||
|
||||
@ -1,183 +0,0 @@
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
def traverse_all_files(output_dir, image_list, curr_dir=None):
|
||||
curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
|
||||
try:
|
||||
f_list = os.listdir(curr_path)
|
||||
except:
|
||||
if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
|
||||
image_list.append(curr_dir)
|
||||
return image_list
|
||||
for file in f_list:
|
||||
file = file if curr_dir is None else os.path.join(curr_dir, file)
|
||||
file_path = os.path.join(curr_path, file)
|
||||
if file[-4:] == ".txt":
|
||||
pass
|
||||
elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
|
||||
image_list.append(file)
|
||||
else:
|
||||
image_list = traverse_all_files(output_dir, image_list, file)
|
||||
return image_list
|
||||
|
||||
|
||||
def get_recent_images(dir_name, page_index, step, image_index, tabname):
|
||||
page_index = int(page_index)
|
||||
image_list = []
|
||||
if not os.path.exists(dir_name):
|
||||
pass
|
||||
elif os.path.isdir(dir_name):
|
||||
image_list = traverse_all_files(dir_name, image_list)
|
||||
image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
|
||||
else:
|
||||
print(f'ERROR: "{dir_name}" is not a directory. Check the path in the settings.', file=sys.stderr)
|
||||
num = 48 if tabname != "extras" else 12
|
||||
max_page_index = len(image_list) // num + 1
|
||||
page_index = max_page_index if page_index == -1 else page_index + step
|
||||
page_index = 1 if page_index < 1 else page_index
|
||||
page_index = max_page_index if page_index > max_page_index else page_index
|
||||
idx_frm = (page_index - 1) * num
|
||||
image_list = image_list[idx_frm:idx_frm + num]
|
||||
image_index = int(image_index)
|
||||
if image_index < 0 or image_index > len(image_list) - 1:
|
||||
current_file = None
|
||||
hidden = None
|
||||
else:
|
||||
current_file = image_list[int(image_index)]
|
||||
hidden = os.path.join(dir_name, current_file)
|
||||
return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
|
||||
|
||||
|
||||
def first_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, 1, 0, image_index, tabname)
|
||||
|
||||
|
||||
def end_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, -1, 0, image_index, tabname)
|
||||
|
||||
|
||||
def prev_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, page_index, -1, image_index, tabname)
|
||||
|
||||
|
||||
def next_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, page_index, 1, image_index, tabname)
|
||||
|
||||
|
||||
def page_index_change(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, page_index, 0, image_index, tabname)
|
||||
|
||||
|
||||
def show_image_info(num, image_path, filenames):
|
||||
# print(f"select image {num}")
|
||||
file = filenames[int(num)]
|
||||
return file, num, os.path.join(image_path, file)
|
||||
|
||||
|
||||
def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
|
||||
if name == "":
|
||||
return filenames, delete_num
|
||||
else:
|
||||
delete_num = int(delete_num)
|
||||
index = list(filenames).index(name)
|
||||
i = 0
|
||||
new_file_list = []
|
||||
for name in filenames:
|
||||
if i >= index and i < index + delete_num:
|
||||
path = os.path.join(dir_name, name)
|
||||
if os.path.exists(path):
|
||||
print(f"Delete file {path}")
|
||||
os.remove(path)
|
||||
txt_file = os.path.splitext(path)[0] + ".txt"
|
||||
if os.path.exists(txt_file):
|
||||
os.remove(txt_file)
|
||||
else:
|
||||
print(f"Not exists file {path}")
|
||||
else:
|
||||
new_file_list.append(name)
|
||||
i += 1
|
||||
return new_file_list, 1
|
||||
|
||||
|
||||
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
|
||||
if opts.outdir_samples != "":
|
||||
dir_name = opts.outdir_samples
|
||||
elif tabname == "txt2img":
|
||||
dir_name = opts.outdir_txt2img_samples
|
||||
elif tabname == "img2img":
|
||||
dir_name = opts.outdir_img2img_samples
|
||||
elif tabname == "extras":
|
||||
dir_name = opts.outdir_extras_samples
|
||||
else:
|
||||
return
|
||||
with gr.Row():
|
||||
renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
|
||||
first_page = gr.Button('First Page')
|
||||
prev_page = gr.Button('Prev Page')
|
||||
page_index = gr.Number(value=1, label="Page Index")
|
||||
next_page = gr.Button('Next Page')
|
||||
end_page = gr.Button('End Page')
|
||||
with gr.Row(elem_id=tabname + "_images_history"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
|
||||
with gr.Row():
|
||||
delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
|
||||
delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
|
||||
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
img_file_info = gr.Textbox(label="Generate Info", interactive=False)
|
||||
img_file_name = gr.Textbox(label="File Name", interactive=False)
|
||||
with gr.Row():
|
||||
# hiden items
|
||||
|
||||
img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
|
||||
tabname_box = gr.Textbox(tabname, visible=False)
|
||||
image_index = gr.Textbox(value=-1, visible=False)
|
||||
set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
|
||||
filenames = gr.State()
|
||||
hidden = gr.Image(type="pil", visible=False)
|
||||
info1 = gr.Textbox(visible=False)
|
||||
info2 = gr.Textbox(visible=False)
|
||||
|
||||
# turn pages
|
||||
gallery_inputs = [img_path, page_index, image_index, tabname_box]
|
||||
gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
|
||||
|
||||
first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
# page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
|
||||
|
||||
# other funcitons
|
||||
set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
|
||||
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
|
||||
delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
|
||||
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
|
||||
|
||||
# pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
|
||||
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
|
||||
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
|
||||
|
||||
|
||||
def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history:
|
||||
with gr.Tabs() as tabs:
|
||||
with gr.Tab("txt2img history"):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
|
||||
show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
|
||||
with gr.Tab("img2img history"):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
|
||||
show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
|
||||
with gr.Tab("extras history"):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
|
||||
show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
|
||||
return images_history
|
||||
@ -0,0 +1,5 @@
|
||||
import sys
|
||||
|
||||
# this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it
|
||||
if "--xformers" not in "".join(sys.argv):
|
||||
sys.modules["xformers"] = None
|
||||
@ -1,42 +0,0 @@
|
||||
import torch
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
from PIL import Image
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_feature_extractor = None
|
||||
safety_checker = None
|
||||
|
||||
def numpy_to_pil(images):
|
||||
"""
|
||||
Convert a numpy image or a batch of images to a PIL image.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
images = images[None, ...]
|
||||
images = (images * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
|
||||
return pil_images
|
||||
|
||||
# check and replace nsfw content
|
||||
def check_safety(x_image):
|
||||
global safety_feature_extractor, safety_checker
|
||||
|
||||
if safety_feature_extractor is None:
|
||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||
|
||||
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
|
||||
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
||||
|
||||
return x_checked_image, has_nsfw_concept
|
||||
|
||||
|
||||
def censor_batch(x):
|
||||
x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
|
||||
x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
|
||||
x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
|
||||
|
||||
return x
|
||||
@ -0,0 +1,315 @@
|
||||
import sys
|
||||
import traceback
|
||||
from collections import namedtuple
|
||||
import inspect
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
from gradio import Blocks
|
||||
|
||||
|
||||
def report_exception(c, job):
|
||||
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
class ImageSaveParams:
|
||||
def __init__(self, image, p, filename, pnginfo):
|
||||
self.image = image
|
||||
"""the PIL image itself"""
|
||||
|
||||
self.p = p
|
||||
"""p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""
|
||||
|
||||
self.filename = filename
|
||||
"""name of file that the image would be saved to"""
|
||||
|
||||
self.pnginfo = pnginfo
|
||||
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
|
||||
|
||||
|
||||
class CFGDenoiserParams:
|
||||
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
|
||||
self.x = x
|
||||
"""Latent image representation in the process of being denoised"""
|
||||
|
||||
self.image_cond = image_cond
|
||||
"""Conditioning image"""
|
||||
|
||||
self.sigma = sigma
|
||||
"""Current sigma noise step value"""
|
||||
|
||||
self.sampling_step = sampling_step
|
||||
"""Current Sampling step number"""
|
||||
|
||||
self.total_sampling_steps = total_sampling_steps
|
||||
"""Total number of sampling steps planned"""
|
||||
|
||||
|
||||
class UiTrainTabParams:
|
||||
def __init__(self, txt2img_preview_params):
|
||||
self.txt2img_preview_params = txt2img_preview_params
|
||||
|
||||
|
||||
class ImageGridLoopParams:
|
||||
def __init__(self, imgs, cols, rows):
|
||||
self.imgs = imgs
|
||||
self.cols = cols
|
||||
self.rows = rows
|
||||
|
||||
|
||||
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
||||
callback_map = dict(
|
||||
callbacks_app_started=[],
|
||||
callbacks_model_loaded=[],
|
||||
callbacks_ui_tabs=[],
|
||||
callbacks_ui_train_tabs=[],
|
||||
callbacks_ui_settings=[],
|
||||
callbacks_before_image_saved=[],
|
||||
callbacks_image_saved=[],
|
||||
callbacks_cfg_denoiser=[],
|
||||
callbacks_before_component=[],
|
||||
callbacks_after_component=[],
|
||||
callbacks_image_grid=[],
|
||||
callbacks_infotext_pasted=[],
|
||||
callbacks_script_unloaded=[],
|
||||
)
|
||||
|
||||
|
||||
def clear_callbacks():
|
||||
for callback_list in callback_map.values():
|
||||
callback_list.clear()
|
||||
|
||||
|
||||
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
||||
for c in callback_map['callbacks_app_started']:
|
||||
try:
|
||||
c.callback(demo, app)
|
||||
except Exception:
|
||||
report_exception(c, 'app_started_callback')
|
||||
|
||||
|
||||
def model_loaded_callback(sd_model):
|
||||
for c in callback_map['callbacks_model_loaded']:
|
||||
try:
|
||||
c.callback(sd_model)
|
||||
except Exception:
|
||||
report_exception(c, 'model_loaded_callback')
|
||||
|
||||
|
||||
def ui_tabs_callback():
|
||||
res = []
|
||||
|
||||
for c in callback_map['callbacks_ui_tabs']:
|
||||
try:
|
||||
res += c.callback() or []
|
||||
except Exception:
|
||||
report_exception(c, 'ui_tabs_callback')
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def ui_train_tabs_callback(params: UiTrainTabParams):
|
||||
for c in callback_map['callbacks_ui_train_tabs']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'callbacks_ui_train_tabs')
|
||||
|
||||
|
||||
def ui_settings_callback():
|
||||
for c in callback_map['callbacks_ui_settings']:
|
||||
try:
|
||||
c.callback()
|
||||
except Exception:
|
||||
report_exception(c, 'ui_settings_callback')
|
||||
|
||||
|
||||
def before_image_saved_callback(params: ImageSaveParams):
|
||||
for c in callback_map['callbacks_before_image_saved']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'before_image_saved_callback')
|
||||
|
||||
|
||||
def image_saved_callback(params: ImageSaveParams):
|
||||
for c in callback_map['callbacks_image_saved']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'image_saved_callback')
|
||||
|
||||
|
||||
def cfg_denoiser_callback(params: CFGDenoiserParams):
|
||||
for c in callback_map['callbacks_cfg_denoiser']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'cfg_denoiser_callback')
|
||||
|
||||
|
||||
def before_component_callback(component, **kwargs):
|
||||
for c in callback_map['callbacks_before_component']:
|
||||
try:
|
||||
c.callback(component, **kwargs)
|
||||
except Exception:
|
||||
report_exception(c, 'before_component_callback')
|
||||
|
||||
|
||||
def after_component_callback(component, **kwargs):
|
||||
for c in callback_map['callbacks_after_component']:
|
||||
try:
|
||||
c.callback(component, **kwargs)
|
||||
except Exception:
|
||||
report_exception(c, 'after_component_callback')
|
||||
|
||||
|
||||
def image_grid_callback(params: ImageGridLoopParams):
|
||||
for c in callback_map['callbacks_image_grid']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'image_grid')
|
||||
|
||||
|
||||
def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
|
||||
for c in callback_map['callbacks_infotext_pasted']:
|
||||
try:
|
||||
c.callback(infotext, params)
|
||||
except Exception:
|
||||
report_exception(c, 'infotext_pasted')
|
||||
|
||||
|
||||
def script_unloaded_callback():
|
||||
for c in reversed(callback_map['callbacks_script_unloaded']):
|
||||
try:
|
||||
c.callback()
|
||||
except Exception:
|
||||
report_exception(c, 'script_unloaded')
|
||||
|
||||
|
||||
def add_callback(callbacks, fun):
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||
|
||||
callbacks.append(ScriptCallback(filename, fun))
|
||||
|
||||
|
||||
def remove_current_script_callbacks():
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||
if filename == 'unknown file':
|
||||
return
|
||||
for callback_list in callback_map.values():
|
||||
for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
|
||||
callback_list.remove(callback_to_remove)
|
||||
|
||||
|
||||
def remove_callbacks_for_function(callback_func):
|
||||
for callback_list in callback_map.values():
|
||||
for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
|
||||
callback_list.remove(callback_to_remove)
|
||||
|
||||
|
||||
def on_app_started(callback):
|
||||
"""register a function to be called when the webui started, the gradio `Block` component and
|
||||
fastapi `FastAPI` object are passed as the arguments"""
|
||||
add_callback(callback_map['callbacks_app_started'], callback)
|
||||
|
||||
|
||||
def on_model_loaded(callback):
|
||||
"""register a function to be called when the stable diffusion model is created; the model is
|
||||
passed as an argument; this function is also called when the script is reloaded. """
|
||||
add_callback(callback_map['callbacks_model_loaded'], callback)
|
||||
|
||||
|
||||
def on_ui_tabs(callback):
|
||||
"""register a function to be called when the UI is creating new tabs.
|
||||
The function must either return a None, which means no new tabs to be added, or a list, where
|
||||
each element is a tuple:
|
||||
(gradio_component, title, elem_id)
|
||||
|
||||
gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
|
||||
title is tab text displayed to user in the UI
|
||||
elem_id is HTML id for the tab
|
||||
"""
|
||||
add_callback(callback_map['callbacks_ui_tabs'], callback)
|
||||
|
||||
|
||||
def on_ui_train_tabs(callback):
|
||||
"""register a function to be called when the UI is creating new tabs for the train tab.
|
||||
Create your new tabs with gr.Tab.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_ui_train_tabs'], callback)
|
||||
|
||||
|
||||
def on_ui_settings(callback):
|
||||
"""register a function to be called before UI settings are populated; add your settings
|
||||
by using shared.opts.add_option(shared.OptionInfo(...)) """
|
||||
add_callback(callback_map['callbacks_ui_settings'], callback)
|
||||
|
||||
|
||||
def on_before_image_saved(callback):
|
||||
"""register a function to be called before an image is saved to a file.
|
||||
The callback is called with one argument:
|
||||
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_before_image_saved'], callback)
|
||||
|
||||
|
||||
def on_image_saved(callback):
|
||||
"""register a function to be called after an image is saved to a file.
|
||||
The callback is called with one argument:
|
||||
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_image_saved'], callback)
|
||||
|
||||
|
||||
def on_cfg_denoiser(callback):
|
||||
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
|
||||
The callback is called with one argument:
|
||||
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_cfg_denoiser'], callback)
|
||||
|
||||
|
||||
def on_before_component(callback):
|
||||
"""register a function to be called before a component is created.
|
||||
The callback is called with arguments:
|
||||
- component - gradio component that is about to be created.
|
||||
- **kwargs - args to gradio.components.IOComponent.__init__ function
|
||||
|
||||
Use elem_id/label fields of kwargs to figure out which component it is.
|
||||
This can be useful to inject your own components somewhere in the middle of vanilla UI.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_before_component'], callback)
|
||||
|
||||
|
||||
def on_after_component(callback):
|
||||
"""register a function to be called after a component is created. See on_before_component for more."""
|
||||
add_callback(callback_map['callbacks_after_component'], callback)
|
||||
|
||||
|
||||
def on_image_grid(callback):
|
||||
"""register a function to be called before making an image grid.
|
||||
The callback is called with one argument:
|
||||
- params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_image_grid'], callback)
|
||||
|
||||
|
||||
def on_infotext_pasted(callback):
|
||||
"""register a function to be called before applying an infotext.
|
||||
The callback is called with two arguments:
|
||||
- infotext: str - raw infotext.
|
||||
- result: Dict[str, any] - parsed infotext parameters.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_infotext_pasted'], callback)
|
||||
|
||||
|
||||
def on_script_unloaded(callback):
|
||||
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
|
||||
the script did should be reverted here"""
|
||||
|
||||
add_callback(callback_map['callbacks_script_unloaded'], callback)
|
||||
@ -0,0 +1,34 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
def load_module(path):
|
||||
with open(path, "r", encoding="utf8") as file:
|
||||
text = file.read()
|
||||
|
||||
compiled = compile(text, path, 'exec')
|
||||
module = ModuleType(os.path.basename(path))
|
||||
exec(compiled, module.__dict__)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def preload_extensions(extensions_dir, parser):
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
for dirname in sorted(os.listdir(extensions_dir)):
|
||||
preload_script = os.path.join(extensions_dir, dirname, "preload.py")
|
||||
if not os.path.isfile(preload_script):
|
||||
continue
|
||||
|
||||
try:
|
||||
module = load_module(preload_script)
|
||||
if hasattr(module, 'preload'):
|
||||
module.preload(parser)
|
||||
|
||||
except Exception:
|
||||
print(f"Error running preload() for {preload_script}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
@ -0,0 +1,88 @@
|
||||
import ldm.modules.encoders.modules
|
||||
import open_clip
|
||||
import torch
|
||||
import transformers.utils.hub
|
||||
|
||||
|
||||
class DisableInitialization:
|
||||
"""
|
||||
When an object of this class enters a `with` block, it starts:
|
||||
- preventing torch's layer initialization functions from working
|
||||
- changes CLIP and OpenCLIP to not download model weights
|
||||
- changes CLIP to not make requests to check if there is a new version of a file you already have
|
||||
|
||||
When it leaves the block, it reverts everything to how it was before.
|
||||
|
||||
Use it like this:
|
||||
```
|
||||
with DisableInitialization():
|
||||
do_things()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.replaced = []
|
||||
|
||||
def replace(self, obj, field, func):
|
||||
original = getattr(obj, field, None)
|
||||
if original is None:
|
||||
return None
|
||||
|
||||
self.replaced.append((obj, field, original))
|
||||
setattr(obj, field, func)
|
||||
|
||||
return original
|
||||
|
||||
def __enter__(self):
|
||||
def do_nothing(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
|
||||
return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
|
||||
|
||||
def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
|
||||
|
||||
def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
|
||||
args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
|
||||
return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)
|
||||
|
||||
def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
|
||||
|
||||
# this file is always 404, prevent making request
|
||||
if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
|
||||
return None
|
||||
|
||||
try:
|
||||
res = original(url, *args, local_files_only=True, **kwargs)
|
||||
if res is None:
|
||||
res = original(url, *args, local_files_only=False, **kwargs)
|
||||
return res
|
||||
except Exception as e:
|
||||
return original(url, *args, local_files_only=False, **kwargs)
|
||||
|
||||
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
|
||||
return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)
|
||||
|
||||
def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
|
||||
return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)
|
||||
|
||||
def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
|
||||
return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
|
||||
|
||||
self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
|
||||
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
|
||||
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
|
||||
self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
||||
self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
||||
self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
||||
self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
||||
self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
||||
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
for obj, field, original in self.replaced:
|
||||
setattr(obj, field, original)
|
||||
|
||||
self.replaced.clear()
|
||||
|
||||
@ -0,0 +1,10 @@
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
def BasicTransformerBlock_forward(self, x, context=None):
|
||||
return checkpoint(self._forward, x, context)
|
||||
|
||||
def AttentionBlock_forward(self, x):
|
||||
return checkpoint(self._forward, x)
|
||||
|
||||
def ResBlock_forward(self, x, emb):
|
||||
return checkpoint(self._forward, x, emb)
|
||||
@ -0,0 +1,308 @@
|
||||
import math
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
|
||||
from modules import prompt_parser, devices, sd_hijack
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
class PromptChunk:
|
||||
"""
|
||||
This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
|
||||
If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
|
||||
Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
|
||||
so just 75 tokens from prompt.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.tokens = []
|
||||
self.multipliers = []
|
||||
self.fixes = []
|
||||
|
||||
|
||||
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
|
||||
"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
|
||||
chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
|
||||
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
"""A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
|
||||
have unlimited prompt length and assign weights to tokens in prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__()
|
||||
|
||||
self.wrapped = wrapped
|
||||
"""Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
|
||||
depending on model."""
|
||||
|
||||
self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
|
||||
self.chunk_length = 75
|
||||
|
||||
def empty_chunk(self):
|
||||
"""creates an empty PromptChunk and returns it"""
|
||||
|
||||
chunk = PromptChunk()
|
||||
chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
|
||||
chunk.multipliers = [1.0] * (self.chunk_length + 2)
|
||||
return chunk
|
||||
|
||||
def get_target_prompt_token_count(self, token_count):
|
||||
"""returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
|
||||
|
||||
return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
|
||||
|
||||
def tokenize(self, texts):
|
||||
"""Converts a batch of texts into a batch of token ids"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
"""
|
||||
converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens;
|
||||
All python lists with tokens are assumed to have same length, usually 77.
|
||||
if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
|
||||
model - can be 768 and 1024.
|
||||
Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
"""Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
|
||||
transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_line(self, line):
|
||||
"""
|
||||
this transforms a single prompt into a list of PromptChunk objects - as many as needed to
|
||||
represent the prompt.
|
||||
Returns the list and the total number of tokens in the prompt.
|
||||
"""
|
||||
|
||||
if opts.enable_emphasis:
|
||||
parsed = prompt_parser.parse_prompt_attention(line)
|
||||
else:
|
||||
parsed = [[line, 1.0]]
|
||||
|
||||
tokenized = self.tokenize([text for text, _ in parsed])
|
||||
|
||||
chunks = []
|
||||
chunk = PromptChunk()
|
||||
token_count = 0
|
||||
last_comma = -1
|
||||
|
||||
def next_chunk():
|
||||
"""puts current chunk into the list of results and produces the next one - empty"""
|
||||
nonlocal token_count
|
||||
nonlocal last_comma
|
||||
nonlocal chunk
|
||||
|
||||
token_count += len(chunk.tokens)
|
||||
to_add = self.chunk_length - len(chunk.tokens)
|
||||
if to_add > 0:
|
||||
chunk.tokens += [self.id_end] * to_add
|
||||
chunk.multipliers += [1.0] * to_add
|
||||
|
||||
chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
|
||||
chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
|
||||
|
||||
last_comma = -1
|
||||
chunks.append(chunk)
|
||||
chunk = PromptChunk()
|
||||
|
||||
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||
position = 0
|
||||
while position < len(tokens):
|
||||
token = tokens[position]
|
||||
|
||||
if token == self.comma_token:
|
||||
last_comma = len(chunk.tokens)
|
||||
|
||||
# this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
|
||||
# is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
|
||||
elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
|
||||
break_location = last_comma + 1
|
||||
|
||||
reloc_tokens = chunk.tokens[break_location:]
|
||||
reloc_mults = chunk.multipliers[break_location:]
|
||||
|
||||
chunk.tokens = chunk.tokens[:break_location]
|
||||
chunk.multipliers = chunk.multipliers[:break_location]
|
||||
|
||||
next_chunk()
|
||||
chunk.tokens = reloc_tokens
|
||||
chunk.multipliers = reloc_mults
|
||||
|
||||
if len(chunk.tokens) == self.chunk_length:
|
||||
next_chunk()
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
|
||||
if embedding is None:
|
||||
chunk.tokens.append(token)
|
||||
chunk.multipliers.append(weight)
|
||||
position += 1
|
||||
continue
|
||||
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
if len(chunk.tokens) + emb_len > self.chunk_length:
|
||||
next_chunk()
|
||||
|
||||
chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
|
||||
|
||||
chunk.tokens += [0] * emb_len
|
||||
chunk.multipliers += [weight] * emb_len
|
||||
position += embedding_length_in_tokens
|
||||
|
||||
if len(chunk.tokens) > 0 or len(chunks) == 0:
|
||||
next_chunk()
|
||||
|
||||
return chunks, token_count
|
||||
|
||||
def process_texts(self, texts):
|
||||
"""
|
||||
Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
|
||||
length, in tokens, of all texts.
|
||||
"""
|
||||
|
||||
token_count = 0
|
||||
|
||||
cache = {}
|
||||
batch_chunks = []
|
||||
for line in texts:
|
||||
if line in cache:
|
||||
chunks = cache[line]
|
||||
else:
|
||||
chunks, current_token_count = self.tokenize_line(line)
|
||||
token_count = max(current_token_count, token_count)
|
||||
|
||||
cache[line] = chunks
|
||||
|
||||
batch_chunks.append(chunks)
|
||||
|
||||
return batch_chunks, token_count
|
||||
|
||||
def forward(self, texts):
|
||||
"""
|
||||
Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
|
||||
Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
|
||||
be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
|
||||
An example shape returned by this function can be: (2, 77, 768).
|
||||
Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
|
||||
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
|
||||
"""
|
||||
|
||||
if opts.use_old_emphasis_implementation:
|
||||
import modules.sd_hijack_clip_old
|
||||
return modules.sd_hijack_clip_old.forward_old(self, texts)
|
||||
|
||||
batch_chunks, token_count = self.process_texts(texts)
|
||||
|
||||
used_embeddings = {}
|
||||
chunk_count = max([len(x) for x in batch_chunks])
|
||||
|
||||
zs = []
|
||||
for i in range(chunk_count):
|
||||
batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
|
||||
|
||||
tokens = [x.tokens for x in batch_chunk]
|
||||
multipliers = [x.multipliers for x in batch_chunk]
|
||||
self.hijack.fixes = [x.fixes for x in batch_chunk]
|
||||
|
||||
for fixes in self.hijack.fixes:
|
||||
for position, embedding in fixes:
|
||||
used_embeddings[embedding.name] = embedding
|
||||
|
||||
z = self.process_tokens(tokens, multipliers)
|
||||
zs.append(z)
|
||||
|
||||
if len(used_embeddings) > 0:
|
||||
embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
|
||||
self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
|
||||
|
||||
return torch.hstack(zs)
|
||||
|
||||
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||
"""
|
||||
sends one single prompt chunk to be encoded by transformers neural network.
|
||||
remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
|
||||
there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
|
||||
Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
|
||||
corresponds to one token.
|
||||
"""
|
||||
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
|
||||
|
||||
# this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
|
||||
if self.id_end != self.id_pad:
|
||||
for batch_pos in range(len(remade_batch_tokens)):
|
||||
index = remade_batch_tokens[batch_pos].index(self.id_end)
|
||||
tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
|
||||
|
||||
z = self.encode_with_transformers(tokens)
|
||||
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
|
||||
original_mean = z.mean()
|
||||
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
new_mean = z.mean()
|
||||
z = z * (original_mean / new_mean)
|
||||
|
||||
return z
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
self.tokenizer = wrapped.tokenizer
|
||||
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
|
||||
self.comma_token = vocab.get(',</w>', None)
|
||||
|
||||
self.token_mults = {}
|
||||
tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||
for text, ident in tokens_with_parens:
|
||||
mult = 1.0
|
||||
for c in text:
|
||||
if c == '[':
|
||||
mult /= 1.1
|
||||
if c == ']':
|
||||
mult *= 1.1
|
||||
if c == '(':
|
||||
mult *= 1.1
|
||||
if c == ')':
|
||||
mult /= 1.1
|
||||
|
||||
if mult != 1.0:
|
||||
self.token_mults[ident] = mult
|
||||
|
||||
self.id_start = self.wrapped.tokenizer.bos_token_id
|
||||
self.id_end = self.wrapped.tokenizer.eos_token_id
|
||||
self.id_pad = self.id_end
|
||||
|
||||
def tokenize(self, texts):
|
||||
tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
return tokenized
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||
|
||||
if opts.CLIP_stop_at_last_layers > 1:
|
||||
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||
z = self.wrapped.transformer.text_model.final_layer_norm(z)
|
||||
else:
|
||||
z = outputs.last_hidden_state
|
||||
|
||||
return z
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
embedding_layer = self.wrapped.transformer.text_model.embeddings
|
||||
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
|
||||
|
||||
return embedded
|
||||
@ -0,0 +1,81 @@
|
||||
from modules import sd_hijack_clip
|
||||
from modules import shared
|
||||
|
||||
|
||||
def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
|
||||
id_start = self.id_start
|
||||
id_end = self.id_end
|
||||
maxlen = self.wrapped.max_length # you get to stay at 77
|
||||
used_custom_terms = []
|
||||
remade_batch_tokens = []
|
||||
hijack_comments = []
|
||||
hijack_fixes = []
|
||||
token_count = 0
|
||||
|
||||
cache = {}
|
||||
batch_tokens = self.tokenize(texts)
|
||||
batch_multipliers = []
|
||||
for tokens in batch_tokens:
|
||||
tuple_tokens = tuple(tokens)
|
||||
|
||||
if tuple_tokens in cache:
|
||||
remade_tokens, fixes, multipliers = cache[tuple_tokens]
|
||||
else:
|
||||
fixes = []
|
||||
remade_tokens = []
|
||||
multipliers = []
|
||||
mult = 1.0
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
|
||||
if mult_change is not None:
|
||||
mult *= mult_change
|
||||
i += 1
|
||||
elif embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(mult)
|
||||
i += 1
|
||||
else:
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
fixes.append((len(remade_tokens), embedding))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [mult] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += embedding_length_in_tokens
|
||||
|
||||
if len(remade_tokens) > maxlen - 2:
|
||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||
ovf = remade_tokens[maxlen - 2:]
|
||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||
|
||||
token_count = len(remade_tokens)
|
||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
||||
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||
|
||||
remade_batch_tokens.append(remade_tokens)
|
||||
hijack_fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
|
||||
def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
|
||||
|
||||
self.hijack.comments += hijack_comments
|
||||
|
||||
if len(used_custom_terms) > 0:
|
||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||
|
||||
self.hijack.fixes = hijack_fixes
|
||||
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||
@ -0,0 +1,111 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
from einops import repeat
|
||||
from omegaconf import ListConfig
|
||||
|
||||
import ldm.models.diffusion.ddpm
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [
|
||||
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
||||
for i in range(len(c[k]))
|
||||
]
|
||||
else:
|
||||
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
if dynamic_threshold is not None:
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
||||
|
||||
|
||||
def should_hijack_inpainting(checkpoint_info):
|
||||
from modules import sd_models
|
||||
|
||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||
cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
|
||||
|
||||
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
|
||||
|
||||
|
||||
def do_inpainting_hijack():
|
||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
||||
|
||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
||||
@ -0,0 +1,37 @@
|
||||
import open_clip.tokenizer
|
||||
import torch
|
||||
|
||||
from modules import sd_hijack_clip, devices
|
||||
from modules.shared import opts
|
||||
|
||||
tokenizer = open_clip.tokenizer._tokenizer
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
|
||||
self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
|
||||
self.id_start = tokenizer.encoder["<start_of_text>"]
|
||||
self.id_end = tokenizer.encoder["<end_of_text>"]
|
||||
self.id_pad = 0
|
||||
|
||||
def tokenize(self, texts):
|
||||
assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
|
||||
|
||||
tokenized = [tokenizer.encode(text) for text in texts]
|
||||
|
||||
return tokenized
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
# set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
|
||||
z = self.wrapped.encode_with_transformer(tokens)
|
||||
|
||||
return z
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
ids = tokenizer.encode(init_text)
|
||||
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
|
||||
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
|
||||
|
||||
return embedded
|
||||
@ -0,0 +1,30 @@
|
||||
import torch
|
||||
|
||||
|
||||
class TorchHijackForUnet:
|
||||
"""
|
||||
This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
|
||||
this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
|
||||
"""
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'cat':
|
||||
return self.cat
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
||||
|
||||
def cat(self, tensors, *args, **kwargs):
|
||||
if len(tensors) == 2:
|
||||
a, b = tensors
|
||||
if a.shape[-2:] != b.shape[-2:]:
|
||||
a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
|
||||
|
||||
tensors = (a, b)
|
||||
|
||||
return torch.cat(tensors, *args, **kwargs)
|
||||
|
||||
|
||||
th = TorchHijackForUnet()
|
||||
@ -0,0 +1,34 @@
|
||||
import open_clip.tokenizer
|
||||
import torch
|
||||
|
||||
from modules import sd_hijack_clip, devices
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
|
||||
self.id_start = wrapped.config.bos_token_id
|
||||
self.id_end = wrapped.config.eos_token_id
|
||||
self.id_pad = wrapped.config.pad_token_id
|
||||
|
||||
self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have </w> bits for comma
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
# there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
|
||||
# trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
|
||||
# layer to work with - you have to use the last
|
||||
|
||||
attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
|
||||
features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
|
||||
z = features['projection_state']
|
||||
|
||||
return z
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
embedding_layer = self.wrapped.roberta.embeddings
|
||||
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
|
||||
|
||||
return embedded
|
||||
@ -0,0 +1,243 @@
|
||||
import torch
|
||||
import safetensors.torch
|
||||
import os
|
||||
import collections
|
||||
from collections import namedtuple
|
||||
from modules import shared, devices, script_callbacks, sd_models
|
||||
from modules.paths import models_path
|
||||
import glob
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
vae_dir = "VAE"
|
||||
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
|
||||
|
||||
|
||||
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
||||
|
||||
|
||||
default_vae_dict = {"auto": "auto", "None": None, None: None}
|
||||
default_vae_list = ["auto", "None"]
|
||||
|
||||
|
||||
default_vae_values = [default_vae_dict[x] for x in default_vae_list]
|
||||
vae_dict = dict(default_vae_dict)
|
||||
vae_list = list(default_vae_list)
|
||||
first_load = True
|
||||
|
||||
|
||||
base_vae = None
|
||||
loaded_vae_file = None
|
||||
checkpoint_info = None
|
||||
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
def get_base_vae(model):
|
||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
||||
return base_vae
|
||||
return None
|
||||
|
||||
|
||||
def store_base_vae(model):
|
||||
global base_vae, checkpoint_info
|
||||
if checkpoint_info != model.sd_checkpoint_info:
|
||||
assert not loaded_vae_file, "Trying to store non-base VAE!"
|
||||
base_vae = deepcopy(model.first_stage_model.state_dict())
|
||||
checkpoint_info = model.sd_checkpoint_info
|
||||
|
||||
|
||||
def delete_base_vae():
|
||||
global base_vae, checkpoint_info
|
||||
base_vae = None
|
||||
checkpoint_info = None
|
||||
|
||||
|
||||
def restore_base_vae(model):
|
||||
global loaded_vae_file
|
||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
|
||||
print("Restoring base VAE")
|
||||
_load_vae_dict(model, base_vae)
|
||||
loaded_vae_file = None
|
||||
delete_base_vae()
|
||||
|
||||
|
||||
def get_filename(filepath):
|
||||
return os.path.splitext(os.path.basename(filepath))[0]
|
||||
|
||||
|
||||
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
|
||||
global vae_dict, vae_list
|
||||
res = {}
|
||||
candidates = [
|
||||
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
|
||||
*glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
|
||||
*glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True),
|
||||
*glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
|
||||
*glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
|
||||
*glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True),
|
||||
]
|
||||
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
|
||||
candidates.append(shared.cmd_opts.vae_path)
|
||||
for filepath in candidates:
|
||||
name = get_filename(filepath)
|
||||
res[name] = filepath
|
||||
vae_list.clear()
|
||||
vae_list.extend(default_vae_list)
|
||||
vae_list.extend(list(res.keys()))
|
||||
vae_dict.clear()
|
||||
vae_dict.update(res)
|
||||
vae_dict.update(default_vae_dict)
|
||||
return vae_list
|
||||
|
||||
|
||||
def get_vae_from_settings(vae_file="auto"):
|
||||
# else, we load from settings, if not set to be default
|
||||
if vae_file == "auto" and shared.opts.sd_vae is not None:
|
||||
# if saved VAE settings isn't recognized, fallback to auto
|
||||
vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
|
||||
# if VAE selected but not found, fallback to auto
|
||||
if vae_file not in default_vae_values and not os.path.isfile(vae_file):
|
||||
vae_file = "auto"
|
||||
print(f"Selected VAE doesn't exist: {vae_file}")
|
||||
return vae_file
|
||||
|
||||
|
||||
def resolve_vae(checkpoint_file=None, vae_file="auto"):
|
||||
global first_load, vae_dict, vae_list
|
||||
|
||||
# if vae_file argument is provided, it takes priority, but not saved
|
||||
if vae_file and vae_file not in default_vae_list:
|
||||
if not os.path.isfile(vae_file):
|
||||
print(f"VAE provided as function argument doesn't exist: {vae_file}")
|
||||
vae_file = "auto"
|
||||
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
|
||||
if first_load and shared.cmd_opts.vae_path is not None:
|
||||
if os.path.isfile(shared.cmd_opts.vae_path):
|
||||
vae_file = shared.cmd_opts.vae_path
|
||||
shared.opts.data['sd_vae'] = get_filename(vae_file)
|
||||
else:
|
||||
print(f"VAE provided as command line argument doesn't exist: {vae_file}")
|
||||
# fallback to selector in settings, if vae selector not set to act as default fallback
|
||||
if not shared.opts.sd_vae_as_default:
|
||||
vae_file = get_vae_from_settings(vae_file)
|
||||
# vae-path cmd arg takes priority for auto
|
||||
if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
|
||||
if os.path.isfile(shared.cmd_opts.vae_path):
|
||||
vae_file = shared.cmd_opts.vae_path
|
||||
print(f"Using VAE provided as command line argument: {vae_file}")
|
||||
# if still not found, try look for ".vae.pt" beside model
|
||||
model_path = os.path.splitext(checkpoint_file)[0]
|
||||
if vae_file == "auto":
|
||||
vae_file_try = model_path + ".vae.pt"
|
||||
if os.path.isfile(vae_file_try):
|
||||
vae_file = vae_file_try
|
||||
print(f"Using VAE found similar to selected model: {vae_file}")
|
||||
# if still not found, try look for ".vae.ckpt" beside model
|
||||
if vae_file == "auto":
|
||||
vae_file_try = model_path + ".vae.ckpt"
|
||||
if os.path.isfile(vae_file_try):
|
||||
vae_file = vae_file_try
|
||||
print(f"Using VAE found similar to selected model: {vae_file}")
|
||||
# if still not found, try look for ".vae.safetensors" beside model
|
||||
if vae_file == "auto":
|
||||
vae_file_try = model_path + ".vae.safetensors"
|
||||
if os.path.isfile(vae_file_try):
|
||||
vae_file = vae_file_try
|
||||
print(f"Using VAE found similar to selected model: {vae_file}")
|
||||
# No more fallbacks for auto
|
||||
if vae_file == "auto":
|
||||
vae_file = None
|
||||
# Last check, just because
|
||||
if vae_file and not os.path.exists(vae_file):
|
||||
vae_file = None
|
||||
|
||||
return vae_file
|
||||
|
||||
|
||||
def load_vae(model, vae_file=None):
|
||||
global first_load, vae_dict, vae_list, loaded_vae_file
|
||||
# save_settings = False
|
||||
|
||||
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
||||
|
||||
if vae_file:
|
||||
if cache_enabled and vae_file in checkpoints_loaded:
|
||||
# use vae checkpoint cache
|
||||
print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
|
||||
store_base_vae(model)
|
||||
_load_vae_dict(model, checkpoints_loaded[vae_file])
|
||||
else:
|
||||
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
|
||||
print(f"Loading VAE weights from: {vae_file}")
|
||||
store_base_vae(model)
|
||||
|
||||
vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
|
||||
vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
_load_vae_dict(model, vae_dict_1)
|
||||
|
||||
if cache_enabled:
|
||||
# cache newly loaded vae
|
||||
checkpoints_loaded[vae_file] = vae_dict_1.copy()
|
||||
|
||||
# clean up cache if limit is reached
|
||||
if cache_enabled:
|
||||
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
|
||||
# If vae used is not in dict, update it
|
||||
# It will be removed on refresh though
|
||||
vae_opt = get_filename(vae_file)
|
||||
if vae_opt not in vae_dict:
|
||||
vae_dict[vae_opt] = vae_file
|
||||
vae_list.append(vae_opt)
|
||||
elif loaded_vae_file:
|
||||
restore_base_vae(model)
|
||||
|
||||
loaded_vae_file = vae_file
|
||||
|
||||
first_load = False
|
||||
|
||||
|
||||
# don't call this from outside
|
||||
def _load_vae_dict(model, vae_dict_1):
|
||||
model.first_stage_model.load_state_dict(vae_dict_1)
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
|
||||
def clear_loaded_vae():
|
||||
global loaded_vae_file
|
||||
loaded_vae_file = None
|
||||
|
||||
|
||||
def reload_vae_weights(sd_model=None, vae_file="auto"):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
|
||||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
|
||||
checkpoint_info = sd_model.sd_checkpoint_info
|
||||
checkpoint_file = checkpoint_info.filename
|
||||
vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
|
||||
|
||||
if loaded_vae_file == vae_file:
|
||||
return
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
load_vae(sd_model, vae_file)
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
|
||||
print("VAE Weights loaded.")
|
||||
return sd_model
|
||||
@ -0,0 +1,58 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from modules import devices, paths
|
||||
|
||||
sd_vae_approx_model = None
|
||||
|
||||
|
||||
class VAEApprox(nn.Module):
|
||||
def __init__(self):
|
||||
super(VAEApprox, self).__init__()
|
||||
self.conv1 = nn.Conv2d(4, 8, (7, 7))
|
||||
self.conv2 = nn.Conv2d(8, 16, (5, 5))
|
||||
self.conv3 = nn.Conv2d(16, 32, (3, 3))
|
||||
self.conv4 = nn.Conv2d(32, 64, (3, 3))
|
||||
self.conv5 = nn.Conv2d(64, 32, (3, 3))
|
||||
self.conv6 = nn.Conv2d(32, 16, (3, 3))
|
||||
self.conv7 = nn.Conv2d(16, 8, (3, 3))
|
||||
self.conv8 = nn.Conv2d(8, 3, (3, 3))
|
||||
|
||||
def forward(self, x):
|
||||
extra = 11
|
||||
x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
|
||||
x = nn.functional.pad(x, (extra, extra, extra, extra))
|
||||
|
||||
for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
|
||||
x = layer(x)
|
||||
x = nn.functional.leaky_relu(x, 0.1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def model():
|
||||
global sd_vae_approx_model
|
||||
|
||||
if sd_vae_approx_model is None:
|
||||
sd_vae_approx_model = VAEApprox()
|
||||
sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
|
||||
sd_vae_approx_model.eval()
|
||||
sd_vae_approx_model.to(devices.device, devices.dtype)
|
||||
|
||||
return sd_vae_approx_model
|
||||
|
||||
|
||||
def cheap_approximation(sample):
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
|
||||
|
||||
coefs = torch.tensor([
|
||||
[0.298, 0.207, 0.208],
|
||||
[0.187, 0.286, 0.173],
|
||||
[-0.158, 0.189, 0.264],
|
||||
[-0.184, -0.271, -0.473],
|
||||
]).to(sample.device)
|
||||
|
||||
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
|
||||
|
||||
return x_sample
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue