Merge branch 'master' into tensorboard
commit 9cd7716753
@@ -0,0 +1,5 @@
blank_issues_enabled: false
contact_links:
  - name: WebUI Community Support
    url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
    about: Please ask and answer questions here.
@@ -0,0 +1,29 @@
name: Run basic features tests on CPU with empty SD model

on:
  - push
  - pull_request

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout Code
        uses: actions/checkout@v3
      - name: Set up Python 3.10
        uses: actions/setup-python@v4
        with:
          python-version: 3.10.6
          cache: pip
          cache-dependency-path: |
            **/requirements*txt
      - name: Run tests
        run: python launch.py --tests --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
      - name: Upload main app stdout-stderr
        uses: actions/upload-artifact@v3
        if: always()
        with:
          name: stdout-stderr
          path: |
            test/stdout.txt
            test/stderr.txt
@@ -1 +1,12 @@
* @AUTOMATIC1111

# if you were managing a localization and were removed from this file, this is because
# the intended way to do localizations now is via extensions. See:
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions
# Make a repo with your localization and since you are still listed as a collaborator
# you can add it to the wiki page yourself. This change is because some people complained
# the git commit log is cluttered with things unrelated to almost everyone and
# because I believe this is the best overall for the project to handle localizations almost
# entirely without my oversight.
@@ -0,0 +1,72 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: modules.xlmr.BertSeriesModelWithTransformation
      params:
        name: "XLMR-Large"
@@ -0,0 +1,70 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
@@ -0,0 +1,6 @@
import os
from modules import paths


def preload(parser):
    parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
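The preload hook above follows a simple convention: a loader imports the module and, if it defines preload, calls it with a shared argparse parser so each component can register its own command-line flags before launch. Below is a minimal sketch of that convention; the loader function and the module name passed to it are hypothetical illustrations, not webui's actual loader code.

import argparse
import importlib


def run_preloads(parser: argparse.ArgumentParser, module_names):
    # Hypothetical loader: import each module and invoke its optional preload hook.
    for name in module_names:
        module = importlib.import_module(name)
        hook = getattr(module, "preload", None)
        if callable(hook):
            hook(parser)  # the hook adds flags such as --ldsr-models-path


parser = argparse.ArgumentParser()
run_preloads(parser, ["ldsr_preload"])  # "ldsr_preload" is a placeholder module name
args, _ = parser.parse_known_args()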
@@ -0,0 +1,286 @@
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder

import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from packaging import version
from torch.optim.lr_scheduler import LambdaLR
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.ema import LitEma
from ldm.util import instantiate_from_config

import ldm.models.autoencoder


class VQModel(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        quant, diff, (_,_,ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val"+suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val"+suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor*self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3:
                    xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x


class VQModelInterface(VQModel):
    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, *args, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec


setattr(ldm.models.autoencoder, "VQModel", VQModel)
setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
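Because the two setattr calls run at import time, merely importing this module is enough to make configs whose target is ldm.models.autoencoder.VQModel resolve again. A minimal sketch of the effect, assuming the file is importable as modules.sd_hijack_autoencoder (the file path is not shown in this diff, so the module name here is an assumption):

import ldm.models.autoencoder
import modules.sd_hijack_autoencoder  # importing runs the setattr hijack above

# The classes removed upstream are reachable at their old locations again,
# so instantiate_from_config can build LDSR's first-stage model as before.
assert ldm.models.autoencoder.VQModel is modules.sd_hijack_autoencoder.VQModel
assert ldm.models.autoencoder.VQModelInterface is modules.sd_hijack_autoencoder.VQModelInterface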
File diff suppressed because it is too large
@@ -0,0 +1,6 @@
import os
from modules import paths


def preload(parser):
    parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
@@ -0,0 +1,6 @@
import os
from modules import paths


def preload(parser):
    parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
@@ -0,0 +1,107 @@
// Stable Diffusion WebUI - Bracket checker
// Version 1.0
// By Hingashi no Florin/Bwin4L
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.

function checkBrackets(evt) {
  // var declarations added so these helpers don't leak as implicit globals
  var textArea = evt.target;
  var tabName = evt.target.parentElement.parentElement.id.split("_")[0];
  var counterElt = document.querySelector('gradio-app').shadowRoot.querySelector('#' + tabName + '_token_counter');

  var promptName = evt.target.parentElement.parentElement.id.includes('neg') ? ' negative' : '';

  var errorStringParen  = '(' + tabName + promptName + ' prompt) - Different number of opening and closing parentheses detected.\n';
  var errorStringSquare = '[' + tabName + promptName + ' prompt] - Different number of opening and closing square brackets detected.\n';
  var errorStringCurly  = '{' + tabName + promptName + ' prompt} - Different number of opening and closing curly brackets detected.\n';

  var openBracketRegExp = /\(/g;
  var closeBracketRegExp = /\)/g;

  var openSquareBracketRegExp = /\[/g;
  var closeSquareBracketRegExp = /\]/g;

  var openCurlyBracketRegExp = /\{/g;
  var closeCurlyBracketRegExp = /\}/g;

  var totalOpenBracketMatches = 0;
  var totalCloseBracketMatches = 0;
  var totalOpenSquareBracketMatches = 0;
  var totalCloseSquareBracketMatches = 0;
  var totalOpenCurlyBracketMatches = 0;
  var totalCloseCurlyBracketMatches = 0;

  var openBracketMatches = textArea.value.match(openBracketRegExp);
  if(openBracketMatches) {
    totalOpenBracketMatches = openBracketMatches.length;
  }

  var closeBracketMatches = textArea.value.match(closeBracketRegExp);
  if(closeBracketMatches) {
    totalCloseBracketMatches = closeBracketMatches.length;
  }

  var openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
  if(openSquareBracketMatches) {
    totalOpenSquareBracketMatches = openSquareBracketMatches.length;
  }

  var closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
  if(closeSquareBracketMatches) {
    totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
  }

  var openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
  if(openCurlyBracketMatches) {
    totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
  }

  var closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
  if(closeCurlyBracketMatches) {
    totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
  }

  if(totalOpenBracketMatches != totalCloseBracketMatches) {
    if(!counterElt.title.includes(errorStringParen)) {
      counterElt.title += errorStringParen;
    }
  } else {
    counterElt.title = counterElt.title.replace(errorStringParen, '');
  }

  if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
    if(!counterElt.title.includes(errorStringSquare)) {
      counterElt.title += errorStringSquare;
    }
  } else {
    counterElt.title = counterElt.title.replace(errorStringSquare, '');
  }

  if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
    if(!counterElt.title.includes(errorStringCurly)) {
      counterElt.title += errorStringCurly;
    }
  } else {
    counterElt.title = counterElt.title.replace(errorStringCurly, '');
  }

  if(counterElt.title != '') {
    counterElt.style = 'color: #FF5555;';
  } else {
    counterElt.style = '';
  }
}

var shadowRootLoaded = setInterval(function() {
  var shadowTextArea = document.querySelector('gradio-app').shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
  if(shadowTextArea.length < 1) {
    return false;
  }

  clearInterval(shadowRootLoaded);

  document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_prompt').onkeyup = checkBrackets;
  document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_neg_prompt').onkeyup = checkBrackets;
  document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_prompt').onkeyup = checkBrackets;
  document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_neg_prompt').onkeyup = checkBrackets;
}, 1000);
@@ -0,0 +1,50 @@
import random

from modules import script_callbacks, shared
import gradio as gr

art_symbol = '\U0001f3a8'  # 🎨
global_prompt = None
related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt"}


def roll_artist(prompt):
    allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
    artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])

    return prompt + ", " + artist.name if prompt != '' else artist.name


def add_roll_button(prompt):
    roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)

    roll.click(
        fn=roll_artist,
        _js="update_txt2img_tokens",
        inputs=[
            prompt,
        ],
        outputs=[
            prompt,
        ]
    )


def after_component(component, **kwargs):
    global global_prompt

    elem_id = kwargs.get('elem_id', None)
    if elem_id not in related_ids:
        return

    if elem_id == "txt2img_prompt":
        global_prompt = component
    elif elem_id == "txt2img_clear_prompt":
        add_roll_button(global_prompt)
    elif elem_id == "img2img_prompt":
        global_prompt = component
    elif elem_id == "img2img_clear_prompt":
        add_roll_button(global_prompt)


script_callbacks.on_after_component(after_component)
@@ -0,0 +1,13 @@
<div>
    <a href="/docs">API</a>
     •
    <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
     •
    <a href="https://gradio.app">Gradio</a>
     •
    <a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
</div>
<br />
<div class="versions">
{versions}
</div>
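The {versions} token in this footer is a plain Python str.format placeholder rather than Gradio templating. A minimal sketch of how a backend could substitute it; the file path and the version string below are illustrative assumptions, not taken from this diff:

from pathlib import Path

footer = Path("html/footer.html").read_text(encoding="utf8")
rendered = footer.format(versions="python: 3.10.6, torch: 1.13.1")  # hypothetical version string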
@@ -0,0 +1,419 @@
<style>
    #licenses h2 {font-size: 1.2em; font-weight: bold; margin-bottom: 0.2em;}
    #licenses small {font-size: 0.95em; opacity: 0.85;}
    #licenses pre { margin: 1em 0 2em 0;}
</style>

<h2><a href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">CodeFormer</a></h2>
<small>Parts of CodeFormer code had to be copied to be compatible with GFPGAN.</small>
<pre>
S-Lab License 1.0

Copyright 2022 S-Lab

Redistribution and use for non-commercial purpose in source and
binary forms, with or without modification, are permitted provided
that the following conditions are met:

1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in
   the documentation and/or other materials provided with the
   distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived
   from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

In the event that redistribution and/or use for commercial purpose in
source or binary forms, with or without modification is required,
please contact the contributor(s) of the work.
</pre>


<h2><a href="https://github.com/victorca25/iNNfer/blob/main/LICENSE">ESRGAN</a></h2>
<small>Code for architecture and reading models copied.</small>
<pre>
MIT License

Copyright (c) 2021 victorca25

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>

<h2><a href="https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE">Real-ESRGAN</a></h2>
<small>Some code is copied to support ESRGAN models.</small>
<pre>
BSD 3-Clause License

Copyright (c) 2021, Xintao Wang
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
</pre>

<h2><a href="https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE">InvokeAI</a></h2>
<small>Some code for compatibility with OSX is taken from lstein's repository.</small>
<pre>
MIT License

Copyright (c) 2022 InvokeAI Team

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>

<h2><a href="https://github.com/Hafiidz/latent-diffusion/blob/main/LICENSE">LDSR</a></h2>
<small>Code added by contributors, most likely copied from this repository.</small>
<pre>
MIT License

Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>

<h2><a href="https://github.com/pharmapsychotic/clip-interrogator/blob/main/LICENSE">CLIP Interrogator</a></h2>
<small>Some small amounts of code borrowed and reworked.</small>
<pre>
MIT License

Copyright (c) 2022 pharmapsychotic

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>

<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
<small>Code added by contributors, most likely copied from this repository.</small>

<pre>
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [2021] [SwinIR Authors]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
</pre>

<h2><a href="https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE">Memory Efficient Attention</a></h2>
<small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small>
<pre>
MIT License

Copyright (c) 2023 Alex Birch
Copyright (c) 2023 Amin Rezaei

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>
@@ -0,0 +1,35 @@

function extensions_apply(_, _){
    // var declarations added so these don't leak as implicit globals
    var disable = []
    var update = []
    gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
        if(x.name.startsWith("enable_") && ! x.checked)
            disable.push(x.name.substr(7))

        if(x.name.startsWith("update_") && x.checked)
            update.push(x.name.substr(7))
    })

    restart_reload()

    return [JSON.stringify(disable), JSON.stringify(update)]
}

function extensions_check(){
    gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
        x.innerHTML = "Loading..."
    })

    return []
}

function install_extension_from_index(button, url){
    button.disabled = "disabled"
    button.value = "Installing..."

    var textarea = gradioApp().querySelector('#extension_to_install textarea')
    textarea.value = url
    textarea.dispatchEvent(new Event("input", { bubbles: true }))

    gradioApp().querySelector('#install_extension_button').click()
}
@@ -0,0 +1,33 @@
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes

let txt2img_gallery, img2img_gallery, modal = undefined;
onUiUpdate(function(){
    if (!txt2img_gallery) {
        txt2img_gallery = attachGalleryListeners("txt2img")
    }
    if (!img2img_gallery) {
        img2img_gallery = attachGalleryListeners("img2img")
    }
    if (!modal) {
        modal = gradioApp().getElementById('lightboxModal')
        modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
    }
});

let modalObserver = new MutationObserver(function(mutations) {
    mutations.forEach(function(mutationRecord) {
        let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
        // parentheses added around the tab check: without them the && binds tighter
        // than ||, so the handler fired whenever img2img was selected, even while the modal was still open
        if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img'))
            gradioApp().getElementById(selectedTab+"_generation_info_button").click()
    });
});

function attachGalleryListeners(tab_name) {
    var gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
    gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
    gallery?.addEventListener('keydown', (e) => {
        if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
            gradioApp().getElementById(tab_name+"_generation_info_button").click()
    });
    return gallery;
}
@@ -0,0 +1,25 @@

function setInactive(elem, inactive){
    console.log(elem)
    if(inactive){
        elem.classList.add('inactive')
    } else{
        elem.classList.remove('inactive')
    }
}

function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
    console.log(enable, width, height, hr_scale, hr_resize_x, hr_resize_y)

    var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
    var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
    var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')

    gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""

    setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0)
    setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0)
    setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0)

    return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]
}
@ -1,206 +0,0 @@
var images_history_click_image = function(){
    if (!this.classList.contains("transform")){
        var gallery = images_history_get_parent_by_class(this, "images_history_cantainor");
        var buttons = gallery.querySelectorAll(".gallery-item");
        var i = 0;
        var hidden_list = [];
        buttons.forEach(function(e){
            if (e.style.display == "none"){
                hidden_list.push(i);
            }
            i += 1;
        })
        if (hidden_list.length > 0){
            setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
        }
    }
    images_history_set_image_info(this);
}

var images_history_click_tab = function(){
    var tabs_box = gradioApp().getElementById("images_history_tab");
    if (!tabs_box.classList.contains(this.getAttribute("tabname"))) {
        gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_renew_page").click();
        tabs_box.classList.add(this.getAttribute("tabname"))
    }
}

function images_history_disabled_del(){
    gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
        btn.setAttribute('disabled','disabled');
    });
}

function images_history_get_parent_by_class(item, class_name){
    var parent = item.parentElement;
    while(!parent.classList.contains(class_name)){
        parent = parent.parentElement;
    }
    return parent;
}

function images_history_get_parent_by_tagname(item, tagname){
    var parent = item.parentElement;
    tagname = tagname.toUpperCase()
    while(parent.tagName != tagname){
        console.log(parent.tagName, tagname)
        parent = parent.parentElement;
    }
    return parent;
}

function images_history_hide_buttons(hidden_list, gallery){
    var buttons = gallery.querySelectorAll(".gallery-item");
    var num = 0;
    buttons.forEach(function(e){
        if (e.style.display == "none"){
            num += 1;
        }
    });
    if (num == hidden_list.length){
        setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
    }
    for (i in hidden_list){
        buttons[hidden_list[i]].style.display = "none";
    }
}

function images_history_set_image_info(button){
    var buttons = images_history_get_parent_by_tagname(button, "DIV").querySelectorAll(".gallery-item");
    var index = -1;
    var i = 0;
    buttons.forEach(function(e){
        if(e == button){
            index = i;
        }
        if(e.style.display != "none"){
            i += 1;
        }
    });
    var gallery = images_history_get_parent_by_class(button, "images_history_cantainor");
    var set_btn = gallery.querySelector(".images_history_set_index");
    var curr_idx = set_btn.getAttribute("img_index", index);
    if (curr_idx != index) {
        set_btn.setAttribute("img_index", index);
        images_history_disabled_del();
    }
    set_btn.click();
}

function images_history_get_current_img(tabname, image_path, files){
    return [
        gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"),
        image_path,
        files
    ];
}

function images_history_delete(del_num, tabname, img_path, img_file_name, page_index, filenames, image_index){
    image_index = parseInt(image_index);
    var tab = gradioApp().getElementById(tabname + '_images_history');
    var set_btn = tab.querySelector(".images_history_set_index");
    var buttons = [];
    tab.querySelectorAll(".gallery-item").forEach(function(e){
        if (e.style.display != 'none'){
            buttons.push(e);
        }
    });
    var img_num = buttons.length / 2;
    if (img_num <= del_num){
        setTimeout(function(tabname){
            gradioApp().getElementById(tabname + '_images_history_renew_page').click();
        }, 30, tabname);
    } else {
        var next_img
        for (var i = 0; i < del_num; i++){
            if (image_index + i < image_index + img_num){
                buttons[image_index + i].style.display = 'none';
                buttons[image_index + img_num + 1].style.display = 'none';
                next_img = image_index + i + 1
            }
        }
        var btn;
        if (next_img >= img_num){
            btn = buttons[image_index - del_num];
        } else {
            btn = buttons[next_img];
        }
        setTimeout(function(btn){btn.click()}, 30, btn);
    }
    images_history_disabled_del();
    return [del_num, tabname, img_path, img_file_name, page_index, filenames, image_index];
}

function images_history_turnpage(img_path, page_index, image_index, tabname){
    var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item");
    buttons.forEach(function(elem) {
        elem.style.display = 'block';
    })
    return [img_path, page_index, image_index, tabname];
}

function images_history_enable_del_buttons(){
    gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
        btn.removeAttribute('disabled');
    })
}

function images_history_init(){
    var load_txt2img_button = gradioApp().getElementById('txt2img_images_history_renew_page')
    if (load_txt2img_button){
        for (var i in images_history_tab_list){
            tab = images_history_tab_list[i];
            gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor");
            gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index");
            gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button");
            gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery");
        }
        var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div");
        tabs_box.setAttribute("id", "images_history_tab");
        var tab_btns = tabs_box.querySelectorAll("button");
        for (var i in images_history_tab_list){
            var tabname = images_history_tab_list[i]
            tab_btns[i].setAttribute("tabname", tabname);

            // this refreshes history upon tab switch
            // until the history is known to work well, which is not the case now, we do not do this at startup
            //tab_btns[i].addEventListener('click', images_history_click_tab);
        }
        tabs_box.classList.add(images_history_tab_list[0]);

        // same as above, at page load
        //load_txt2img_button.click();
    } else {
        setTimeout(images_history_init, 500);
    }
}

var images_history_tab_list = ["txt2img", "img2img", "extras"];
setTimeout(images_history_init, 500);
document.addEventListener("DOMContentLoaded", function() {
    var mutationObserver = new MutationObserver(function(m){
        for (var i in images_history_tab_list){
            let tabname = images_history_tab_list[i]
            var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item');
            buttons.forEach(function(bnt){
                bnt.addEventListener('click', images_history_click_image, true);
            });

            // same as load_txt2img_button.click() above
            /*
            var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
            if (cls_btn){
                cls_btn.addEventListener('click', function(){
                    gradioApp().getElementById(tabname + '_images_history_renew_page').click();
                }, false);
            }*/

        }
    });
    mutationObserver.observe( gradioApp(), { childList:true, subtree:true });
});
Binary file not shown.
@ -0,0 +1,267 @@
import inspect
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers, opts, parser
from typing import Dict, List

API_NOT_ALLOWED = [
    "self",
    "kwargs",
    "sd_model",
    "outpath_samples",
    "outpath_grids",
    "sampler_index",
    "do_not_save_samples",
    "do_not_save_grid",
    "extra_generation_params",
    "overlay_images",
    "do_not_reload_embeddings",
    "seed_enable_extras",
    "prompt_for_display",
    "sampler_noise_scheduler_override",
    "ddim_discretize"
]


class ModelDef(BaseModel):
    """Assistance Class for Pydantic Dynamic Model Generation"""

    field: str
    field_alias: str
    field_type: Any
    field_value: Any
    field_exclude: bool = False


class PydanticModelGenerator:
    """
    Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
    source_data is a snapshot of the default values produced by the class
    params are the names of the actual keys required by __init__
    """

    def __init__(
        self,
        model_name: str = None,
        class_instance = None,
        additional_fields = None,
    ):
        def field_type_generator(k, v):
            # field_type = str if not overrides.get(k) else overrides[k]["type"]
            # print(k, v.annotation, v.default)
            field_type = v.annotation

            return Optional[field_type]

        def merge_class_params(class_):
            all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
            parameters = {}
            for classes in all_classes:
                parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
            return parameters

        self._model_name = model_name
        self._class_data = merge_class_params(class_instance)

        self._model_def = [
            ModelDef(
                field=underscore(k),
                field_alias=k,
                field_type=field_type_generator(k, v),
                field_value=v.default
            )
            for (k, v) in self._class_data.items() if k not in API_NOT_ALLOWED
        ]

        for fields in additional_fields:
            self._model_def.append(ModelDef(
                field=underscore(fields["key"]),
                field_alias=fields["key"],
                field_type=fields["type"],
                field_value=fields["default"],
                field_exclude=fields["exclude"] if "exclude" in fields else False))

    def generate_model(self):
        """
        Creates a pydantic BaseModel
        from the json and overrides provided at initialization
        """
        fields = {
            d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
        }
        DynamicModel = create_model(self._model_name, **fields)
        DynamicModel.__config__.allow_population_by_field_name = True
        DynamicModel.__config__.allow_mutation = True
        return DynamicModel


StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
    "StableDiffusionProcessingTxt2Img",
    StableDiffusionProcessingTxt2Img,
    [
        {"key": "sampler_index", "type": str, "default": "Euler"},
        {"key": "script_name", "type": str, "default": None},
        {"key": "script_args", "type": list, "default": []},
    ]
).generate_model()

StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
    "StableDiffusionProcessingImg2Img",
    StableDiffusionProcessingImg2Img,
    [
        {"key": "sampler_index", "type": str, "default": "Euler"},
        {"key": "init_images", "type": list, "default": None},
        {"key": "denoising_strength", "type": float, "default": 0.75},
        {"key": "mask", "type": str, "default": None},
        {"key": "include_init_images", "type": bool, "default": False, "exclude": True},
        {"key": "script_name", "type": str, "default": None},
        {"key": "script_args", "type": list, "default": []},
    ]
).generate_model()
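The generator above lifts __init__ signatures into pydantic request models at runtime. As a minimal, self-contained sketch of the same idea (the Job class and JobModel name are made up for illustration, not part of the webui):

import inspect
from typing import Optional

from pydantic import Field, create_model


class Job:
    """Toy stand-in for a processing class whose defaults live in __init__."""
    def __init__(self, prompt: str = "", steps: int = 20):
        self.prompt = prompt
        self.steps = steps


# Lift the __init__ parameters (name, annotation, default) into a pydantic model,
# mirroring what PydanticModelGenerator does for the processing classes.
params = inspect.signature(Job.__init__).parameters
fields = {k: (Optional[v.annotation], Field(default=v.default))
          for k, v in params.items() if k != "self"}
JobModel = create_model("JobModel", **fields)

print(JobModel().dict())                # {'prompt': '', 'steps': 20}
print(JobModel(prompt="a cat").dict())  # overrides validate like any pydantic model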
class TextToImageResponse(BaseModel):
    images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
    parameters: dict
    info: str


class ImageToImageResponse(BaseModel):
    images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
    parameters: dict
    info: str


class ExtrasBaseRequest(BaseModel):
    resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
    show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
    gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
    codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
    codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
    upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.")
    upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
    upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
    upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
    upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
    upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
    extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
    upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")


class ExtraBaseResponse(BaseModel):
    html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")


class ExtrasSingleImageRequest(ExtrasBaseRequest):
    image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")


class ExtrasSingleImageResponse(ExtraBaseResponse):
    image: str = Field(default=None, title="Image", description="The generated image in base64 format.")


class FileData(BaseModel):
    data: str = Field(title="File data", description="Base64 representation of the file")
    name: str = Field(title="File name")


class ExtrasBatchImagesRequest(ExtrasBaseRequest):
    imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")


class ExtrasBatchImagesResponse(ExtraBaseResponse):
    images: List[str] = Field(title="Images", description="The generated images in base64 format.")


class PNGInfoRequest(BaseModel):
    image: str = Field(title="Image", description="The base64 encoded PNG image")


class PNGInfoResponse(BaseModel):
    info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
    items: dict = Field(title="Items", description="An object containing all the info the image had")


class ProgressRequest(BaseModel):
    skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")


class ProgressResponse(BaseModel):
    progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
    eta_relative: float = Field(title="ETA in secs")
    state: dict = Field(title="State", description="The current state snapshot")
    current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
    textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")


class InterrogateRequest(BaseModel):
    image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
    model: str = Field(default="clip", title="Model", description="The interrogate model used.")


class InterrogateResponse(BaseModel):
    caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")


class TrainResponse(BaseModel):
    info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")


class CreateResponse(BaseModel):
    info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")


class PreprocessResponse(BaseModel):
    info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")


fields = {}
for key, metadata in opts.data_labels.items():
    value = opts.data.get(key)
    optType = opts.typemap.get(type(metadata.default), type(value))

    if metadata is not None:
        fields.update({key: (Optional[optType], Field(default=metadata.default, description=metadata.label))})
    else:
        fields.update({key: (Optional[optType], Field())})

OptionsModel = create_model("Options", **fields)

flags = {}
_options = vars(parser)['_option_string_actions']
for key in _options:
    if _options[key].dest != 'help':
        flag = _options[key]
        _type = str
        if _options[key].default is not None:
            _type = type(_options[key].default)
        flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})

FlagsModel = create_model("Flags", **flags)


class SamplerItem(BaseModel):
    name: str = Field(title="Name")
    aliases: List[str] = Field(title="Aliases")
    options: Dict[str, str] = Field(title="Options")


class UpscalerItem(BaseModel):
    name: str = Field(title="Name")
    model_name: Optional[str] = Field(title="Model Name")
    model_path: Optional[str] = Field(title="Path")
    model_url: Optional[str] = Field(title="URL")


class SDModelItem(BaseModel):
    title: str = Field(title="Title")
    model_name: str = Field(title="Model Name")
    hash: str = Field(title="Hash")
    filename: str = Field(title="Filename")
    config: str = Field(title="Config file")


class HypernetworkItem(BaseModel):
    name: str = Field(title="Name")
    path: Optional[str] = Field(title="Path")


class FaceRestorerItem(BaseModel):
    name: str = Field(title="Name")
    cmd_dir: Optional[str] = Field(title="Path")


class RealesrganItem(BaseModel):
    name: str = Field(title="Name")
    path: Optional[str] = Field(title="Path")
    scale: Optional[int] = Field(title="Scale")


class PromptStyleItem(BaseModel):
    name: str = Field(title="Name")
    prompt: Optional[str] = Field(title="Prompt")
    negative_prompt: Optional[str] = Field(title="Negative Prompt")


class ArtistItem(BaseModel):
    name: str = Field(title="Name")
    score: float = Field(title="Score")
    category: str = Field(title="Category")


class EmbeddingItem(BaseModel):
    step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
    sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
    sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
    shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
    vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")


class EmbeddingsResponse(BaseModel):
    loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
    skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")


class MemoryResponse(BaseModel):
    ram: dict = Field(title="RAM", description="System memory stats")
    cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
@ -1,99 +0,0 @@
from inflection import underscore
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, create_model
from modules.processing import StableDiffusionProcessingTxt2Img
import inspect


API_NOT_ALLOWED = [
    "self",
    "kwargs",
    "sd_model",
    "outpath_samples",
    "outpath_grids",
    "sampler_index",
    "do_not_save_samples",
    "do_not_save_grid",
    "extra_generation_params",
    "overlay_images",
    "do_not_reload_embeddings",
    "seed_enable_extras",
    "prompt_for_display",
    "sampler_noise_scheduler_override",
    "ddim_discretize"
]


class ModelDef(BaseModel):
    """Assistance Class for Pydantic Dynamic Model Generation"""

    field: str
    field_alias: str
    field_type: Any
    field_value: Any


class PydanticModelGenerator:
    """
    Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
    source_data is a snapshot of the default values produced by the class
    params are the names of the actual keys required by __init__
    """

    def __init__(
        self,
        model_name: str = None,
        class_instance = None,
        additional_fields = None,
    ):
        def field_type_generator(k, v):
            # field_type = str if not overrides.get(k) else overrides[k]["type"]
            # print(k, v.annotation, v.default)
            field_type = v.annotation

            return Optional[field_type]

        def merge_class_params(class_):
            all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
            parameters = {}
            for classes in all_classes:
                parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
            return parameters

        self._model_name = model_name
        self._class_data = merge_class_params(class_instance)
        self._model_def = [
            ModelDef(
                field=underscore(k),
                field_alias=k,
                field_type=field_type_generator(k, v),
                field_value=v.default
            )
            for (k, v) in self._class_data.items() if k not in API_NOT_ALLOWED
        ]

        for fields in additional_fields:
            self._model_def.append(ModelDef(
                field=underscore(fields["key"]),
                field_alias=fields["key"],
                field_type=fields["type"],
                field_value=fields["default"]))

    def generate_model(self):
        """
        Creates a pydantic BaseModel
        from the json and overrides provided at initialization
        """
        fields = {
            d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
        }
        DynamicModel = create_model(self._model_name, **fields)
        DynamicModel.__config__.allow_population_by_field_name = True
        DynamicModel.__config__.allow_mutation = True
        return DynamicModel


StableDiffusionProcessingAPI = PydanticModelGenerator(
    "StableDiffusionProcessingTxt2Img",
    StableDiffusionProcessingTxt2Img,
    [{"key": "sampler_index", "type": str, "default": "Euler"}]
).generate_model()
@ -1,76 +0,0 @@
import os.path
import sys
import traceback

import PIL.Image
import numpy as np
import torch
from basicsr.utils.download_util import load_file_from_url

import modules.upscaler
from modules import devices, modelloader
from modules.bsrgan_model_arch import RRDBNet


class UpscalerBSRGAN(modules.upscaler.Upscaler):
    def __init__(self, dirname):
        self.name = "BSRGAN"
        self.model_name = "BSRGAN 4x"
        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
        self.user_path = dirname
        super().__init__()
        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
        scalers = []
        if len(model_paths) == 0:
            scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
            scalers.append(scaler_data)
        for file in model_paths:
            if "http" in file:
                name = self.model_name
            else:
                name = modelloader.friendly_name(file)
            try:
                scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
                scalers.append(scaler_data)
            except Exception:
                print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
        self.scalers = scalers

    def do_upscale(self, img: PIL.Image, selected_file):
        torch.cuda.empty_cache()
        model = self.load_model(selected_file)
        if model is None:
            return img
        model.to(devices.device_bsrgan)
        torch.cuda.empty_cache()
        img = np.array(img)
        img = img[:, :, ::-1]                # RGB -> BGR
        img = np.moveaxis(img, 2, 0) / 255   # HWC -> CHW, scale to [0, 1]
        img = torch.from_numpy(img).float()
        img = img.unsqueeze(0).to(devices.device_bsrgan)
        with torch.no_grad():
            output = model(img)
        output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = 255. * np.moveaxis(output, 0, 2)   # CHW -> HWC, back to [0, 255]
        output = output.astype(np.uint8)
        output = output[:, :, ::-1]                 # BGR -> RGB
        torch.cuda.empty_cache()
        return PIL.Image.fromarray(output, 'RGB')

    def load_model(self, path: str):
        if "http" in path:
            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
                                          file_name="%s.pth" % self.name, progress=True)
        else:
            filename = path
        # check for None first: os.path.exists() raises on a None argument
        if filename is None or not os.path.exists(filename):
            print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
            return None
        model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4)  # define network
        model.load_state_dict(torch.load(filename), strict=True)
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        return model
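The axis and channel shuffling in do_upscale above is easy to get wrong, so here is a small self-contained check (illustrative only; it uses a random array in place of a real image) that the HWC/CHW and RGB/BGR conversions round-trip:

import numpy as np

img = np.random.randint(0, 256, (8, 8, 3), dtype=np.uint8)      # HWC, RGB

x = np.moveaxis(img[:, :, ::-1], 2, 0) / 255                    # -> CHW, BGR, [0, 1]
back = np.rint(255.0 * np.moveaxis(x, 0, 2)).astype(np.uint8)   # -> HWC, still BGR

assert np.array_equal(back[:, :, ::-1], img)                    # BGR -> RGB restores the input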
@ -1,102 +0,0 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


def initialize_weights(net_l, scale=1):
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)


def make_layer(block, n_layers):
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    return nn.Sequential(*layers)


class ResidualDenseBlock_5C(nn.Module):
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(nn.Module):
    '''Residual in Residual Dense Block'''

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class RRDBNet(nn.Module):
    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
        super(RRDBNet, self).__init__()
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
        self.sf = sf

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if self.sf == 4:
            self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk

        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        if self.sf == 4:
            fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        return out
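As a quick sanity check of the architecture above (illustrative only; nb is reduced from the default 23 so the toy run is fast, and it assumes RRDBNet as defined above is importable), a 4x RRDBNet doubles the spatial size twice via the two nearest-neighbour upsamples in forward:

import torch

net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=2, gc=32, sf=4)
with torch.no_grad():
    out = net(torch.zeros(1, 3, 16, 16))
print(out.shape)  # torch.Size([1, 3, 64, 64]) -- two 2x upsamples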
@ -0,0 +1,98 @@
import html
import sys
import threading
import traceback
import time

from modules import shared

queue_lock = threading.Lock()


def wrap_queued_call(func):
    def f(*args, **kwargs):
        with queue_lock:
            res = func(*args, **kwargs)

        return res

    return f
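To make the serialization behaviour concrete, a small illustrative usage (the busy function is a toy stand-in, not from the repository): even when wrapped calls start on separate threads, queue_lock runs their bodies strictly one at a time.

import threading
import time

def busy(name):
    time.sleep(0.1)
    print(f"{name} finished")

guarded = wrap_queued_call(busy)  # serialized by queue_lock above
threads = [threading.Thread(target=guarded, args=(n,)) for n in ("first", "second")]
for t in threads:
    t.start()
for t in threads:
    t.join()
# total runtime is ~0.2s rather than ~0.1s, because the calls cannot overlap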
def wrap_gradio_gpu_call(func, extra_outputs=None):
    def f(*args, **kwargs):

        shared.state.begin()

        with queue_lock:
            res = func(*args, **kwargs)

        shared.state.end()

        return res

    return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)


def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
    def f(*args, extra_outputs_array=extra_outputs, **kwargs):
        run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
        if run_memmon:
            shared.mem_mon.monitor()
        t = time.perf_counter()

        try:
            res = list(func(*args, **kwargs))
        except Exception as e:
            # When printing out our debug argument list, do not print out more than a MB of text
            max_debug_str_len = 131072  # (1024*1024)/8

            print("Error completing request", file=sys.stderr)
            argStr = f"Arguments: {str(args)} {str(kwargs)}"
            print(argStr[:max_debug_str_len], file=sys.stderr)
            if len(argStr) > max_debug_str_len:
                print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)

            print(traceback.format_exc(), file=sys.stderr)

            shared.state.job = ""
            shared.state.job_count = 0

            if extra_outputs_array is None:
                extra_outputs_array = [None, '']

            res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]

        shared.state.skipped = False
        shared.state.interrupted = False
        shared.state.job_count = 0

        if not add_stats:
            return tuple(res)

        elapsed = time.perf_counter() - t
        elapsed_m = int(elapsed // 60)
        elapsed_s = elapsed % 60
        elapsed_text = f"{elapsed_s:.2f}s"
        if elapsed_m > 0:
            elapsed_text = f"{elapsed_m}m "+elapsed_text

        if run_memmon:
            mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
            active_peak = mem_stats['active_peak']
            reserved_peak = mem_stats['reserved_peak']
            sys_peak = mem_stats['system_peak']
            sys_total = mem_stats['total']
            sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)

            vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
        else:
            vram_html = ''

        # last item is always HTML
        res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"

        return tuple(res)

    return f
@ -1,172 +1,99 @@
Removed (old implementation: TensorFlow DeepDanbooru run in a separate process):

import os.path
from concurrent.futures import ProcessPoolExecutor
import multiprocessing
import time
import re

re_special = re.compile(r'([\\()])')

def get_deepbooru_tags(pil_image):
    """
    This method is for running only one image at a time for simple use. Used by the img2img interrogate.
    """
    from modules import shared  # prevents circular reference

    try:
        create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
        return get_tags_from_process(pil_image)
    finally:
        release_process()


OPT_INCLUDE_RANKS = "include_ranks"
def create_deepbooru_opts():
    from modules import shared

    return {
        "use_spaces": shared.opts.deepbooru_use_spaces,
        "use_escape": shared.opts.deepbooru_escape,
        "alpha_sort": shared.opts.deepbooru_sort_alpha,
        OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
    }


def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
    model, tags = get_deepbooru_tags_model()
    while True:  # while process is running, keep monitoring queue for new image
        pil_image = queue.get()
        if pil_image == "QUIT":
            break
        else:
            deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)


def create_deepbooru_process(threshold, deepbooru_opts):
    """
    Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
    to be processed in a row without reloading the model or creating a new process. To return the data, a shared
    dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned
    to the dictionary and the method adding the image to the queue should wait for this value to be updated with
    the tags.
    """
    from modules import shared  # prevents circular reference
    shared.deepbooru_process_manager = multiprocessing.Manager()
    shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
    shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
    shared.deepbooru_process_return["value"] = -1
    shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
    shared.deepbooru_process.start()


def get_tags_from_process(image):
    from modules import shared

    shared.deepbooru_process_return["value"] = -1
    shared.deepbooru_process_queue.put(image)
    while shared.deepbooru_process_return["value"] == -1:
        time.sleep(0.2)
    caption = shared.deepbooru_process_return["value"]
    shared.deepbooru_process_return["value"] = -1

    return caption


def release_process():
    """
    Stops the deepbooru process to return used memory
    """
    from modules import shared  # prevents circular reference
    shared.deepbooru_process_queue.put("QUIT")
    shared.deepbooru_process.join()
    shared.deepbooru_process_queue = None
    shared.deepbooru_process = None
    shared.deepbooru_process_return = None
    shared.deepbooru_process_manager = None


def get_deepbooru_tags_model():
    import deepdanbooru as dd
    import tensorflow as tf
    import numpy as np
    this_folder = os.path.dirname(__file__)
    model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
    if not os.path.exists(os.path.join(model_path, 'project.json')):
        # there is no point importing these every time
        import zipfile
        from basicsr.utils.download_util import load_file_from_url
        load_file_from_url(
            r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
            model_path)
        with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
            zip_ref.extractall(model_path)
        os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))

    tags = dd.project.load_tags_from_project(model_path)
    model = dd.project.load_model_from_project(
        model_path, compile_model=False
    )
    return model, tags


def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
    import deepdanbooru as dd
    import tensorflow as tf
    import numpy as np

    alpha_sort = deepbooru_opts['alpha_sort']
    use_spaces = deepbooru_opts['use_spaces']
    use_escape = deepbooru_opts['use_escape']
    include_ranks = deepbooru_opts['include_ranks']

    width = model.input_shape[2]
    height = model.input_shape[1]
    image = np.array(pil_image)
    image = tf.image.resize(
        image,
        size=(height, width),
        method=tf.image.ResizeMethod.AREA,
        preserve_aspect_ratio=True,
    )
    image = image.numpy()  # EagerTensor to np.array
    image = dd.image.transform_and_pad_image(image, width, height)
    image = image / 255.0
    image_shape = image.shape
    image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))

    y = model.predict(image)[0]

    result_dict = {}

    for i, tag in enumerate(tags):
        result_dict[tag] = y[i]

    unsorted_tags_in_threshold = []
    result_tags_print = []
    for tag in tags:
        if result_dict[tag] >= threshold:
            if tag.startswith("rating:"):
                continue
            unsorted_tags_in_threshold.append((result_dict[tag], tag))
            result_tags_print.append(f'{result_dict[tag]} {tag}')

    # sort tags
    result_tags_out = []
    sort_ndx = 0
    if alpha_sort:
        sort_ndx = 1

    # sort in reverse by likelihood, or normally for alpha sort, and format tag text as requested
    unsorted_tags_in_threshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
    for weight, tag in unsorted_tags_in_threshold:
        tag_outformat = tag
        if use_spaces:
            tag_outformat = tag_outformat.replace('_', ' ')
        if use_escape:
            tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
        if include_ranks:
            tag_outformat = f"({tag_outformat}:{weight:.3f})"

        result_tags_out.append(tag_outformat)

    print('\n'.join(sorted(result_tags_print, reverse=True)))

    return ', '.join(result_tags_out)

Added (new implementation: in-process TorchDeepDanbooru):

import os
import re

import torch
from PIL import Image
import numpy as np

from modules import modelloader, paths, deepbooru_model, devices, images, shared

re_special = re.compile(r'([\\()])')


class DeepDanbooru:
    def __init__(self):
        self.model = None

    def load(self):
        if self.model is not None:
            return

        files = modelloader.load_models(
            model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
            model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
            ext_filter=[".pt"],
            download_name='model-resnet_custom_v3.pt',
        )

        self.model = deepbooru_model.DeepDanbooruModel()
        self.model.load_state_dict(torch.load(files[0], map_location="cpu"))

        self.model.eval()
        self.model.to(devices.cpu, devices.dtype)

    def start(self):
        self.load()
        self.model.to(devices.device)

    def stop(self):
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model.to(devices.cpu)
            devices.torch_gc()

    def tag(self, pil_image):
        self.start()
        res = self.tag_multi(pil_image)
        self.stop()

        return res

    def tag_multi(self, pil_image, force_disable_ranks=False):
        threshold = shared.opts.interrogate_deepbooru_score_threshold
        use_spaces = shared.opts.deepbooru_use_spaces
        use_escape = shared.opts.deepbooru_escape
        alpha_sort = shared.opts.deepbooru_sort_alpha
        include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks

        pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
        a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255

        with torch.no_grad(), devices.autocast():
            x = torch.from_numpy(a).to(devices.device)
            y = self.model(x)[0].detach().cpu().numpy()

        probability_dict = {}

        for tag, probability in zip(self.model.tags, y):
            if probability < threshold:
                continue

            if tag.startswith("rating:"):
                continue

            probability_dict[tag] = probability

        if alpha_sort:
            tags = sorted(probability_dict)
        else:
            tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]

        res = []

        filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])

        for tag in [x for x in tags if x not in filtertags]:
            probability = probability_dict[tag]
            tag_outformat = tag
            if use_spaces:
                tag_outformat = tag_outformat.replace('_', ' ')

            if use_escape:
                tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)

            if include_ranks:
                tag_outformat = f"({tag_outformat}:{probability:.3f})"

            res.append(tag_outformat)

        return ", ".join(res)


model = DeepDanbooru()
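The post-processing at the end of tag_multi (threshold, spaces, escaping, ranks) can be exercised in isolation; this self-contained sketch feeds it made-up probabilities instead of model output:

import re

re_special = re.compile(r'([\\()])')

probability_dict = {"blue_sky": 0.92, "cloud_(sky)": 0.81, "tree": 0.40}  # fake scores
threshold, use_spaces, use_escape, include_ranks = 0.5, True, True, True

res = []
for tag in sorted(probability_dict, key=probability_dict.get, reverse=True):
    probability = probability_dict[tag]
    if probability < threshold:
        continue
    out = tag.replace('_', ' ') if use_spaces else tag
    if use_escape:
        out = re.sub(re_special, r'\\\1', out)  # escape ( ) \ for prompt syntax
    if include_ranks:
        out = f"({out}:{probability:.3f})"
    res.append(out)

print(", ".join(res))  # (blue sky:0.920), (cloud \(sky\):0.810)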
@ -0,0 +1,676 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more


class DeepDanbooruModel(nn.Module):
    def __init__(self):
        super(DeepDanbooruModel, self).__init__()

        self.tags = []

        self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
        self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
        self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
        self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
        self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
        self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
        self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
        self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
        self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
        self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
        self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
        self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
        self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
        self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
        self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
        self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
        self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
        self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
        self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
        self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
        self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||||
|
self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||||
|
self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||||
|
self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||||
|
self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||||
|
self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||||
|
self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||||
|
self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||||
|
self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||||
|
self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||||
|
self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||||
|
self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||||
|
self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||||
|
self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||||
|
self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||||
|
self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||||
|
self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||||
|
self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||||
|
self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||||
|
self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||||
|
self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||||
|
self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||||
|
self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||||
|
self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||||
|
self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||||
|
self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||||
|
self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||||
|
self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||||
|
self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||||
|
self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||||
|
self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
|
||||||
|
self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
|
||||||
|
self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
|
||||||
|
self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||||
|
self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
||||||
|
self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
||||||
|
self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||||
|
self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
||||||
|
self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
||||||
|
self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||||
|
self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
|
||||||
|
self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
|
||||||
|
self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
|
||||||
|
self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||||
|
self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
||||||
|
self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
||||||
|
self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||||
|
self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
||||||
|
self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
||||||
|
self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||||
|
self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
|
||||||
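
        # Note added for readability (not part of the original diff): the numbered
        # layers above and the tensor-by-tensor forward pass below appear to be
        # machine-generated from a converted DeepDanbooru network; each repeating
        # Conv(1x1) -> Conv(3x3) -> Conv(1x1) group followed by torch.add is a
        # ResNet-style bottleneck block with a residual connection.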

    def forward(self, *inputs):
        t_358, = inputs
        t_359 = t_358.permute(*[0, 3, 1, 2])
        t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
        t_360 = self.n_Conv_0(t_359_padded)
        t_361 = F.relu(t_360)
        t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
        t_362 = self.n_MaxPool_0(t_361)
        t_363 = self.n_Conv_1(t_362)
        t_364 = self.n_Conv_2(t_362)
        t_365 = F.relu(t_364)
        t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
        t_366 = self.n_Conv_3(t_365_padded)
        t_367 = F.relu(t_366)
        t_368 = self.n_Conv_4(t_367)
        t_369 = torch.add(t_368, t_363)
        t_370 = F.relu(t_369)
        t_371 = self.n_Conv_5(t_370)
        t_372 = F.relu(t_371)
        t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
        t_373 = self.n_Conv_6(t_372_padded)
        t_374 = F.relu(t_373)
        t_375 = self.n_Conv_7(t_374)
        t_376 = torch.add(t_375, t_370)
        t_377 = F.relu(t_376)
        t_378 = self.n_Conv_8(t_377)
        t_379 = F.relu(t_378)
        t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
        t_380 = self.n_Conv_9(t_379_padded)
        t_381 = F.relu(t_380)
        t_382 = self.n_Conv_10(t_381)
        t_383 = torch.add(t_382, t_377)
        t_384 = F.relu(t_383)
        t_385 = self.n_Conv_11(t_384)
        t_386 = self.n_Conv_12(t_384)
        t_387 = F.relu(t_386)
        t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
        t_388 = self.n_Conv_13(t_387_padded)
        t_389 = F.relu(t_388)
        t_390 = self.n_Conv_14(t_389)
        t_391 = torch.add(t_390, t_385)
        t_392 = F.relu(t_391)
        t_393 = self.n_Conv_15(t_392)
        t_394 = F.relu(t_393)
        t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
        t_395 = self.n_Conv_16(t_394_padded)
        t_396 = F.relu(t_395)
        t_397 = self.n_Conv_17(t_396)
        t_398 = torch.add(t_397, t_392)
        t_399 = F.relu(t_398)
        t_400 = self.n_Conv_18(t_399)
        t_401 = F.relu(t_400)
        t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
        t_402 = self.n_Conv_19(t_401_padded)
        t_403 = F.relu(t_402)
        t_404 = self.n_Conv_20(t_403)
        t_405 = torch.add(t_404, t_399)
        t_406 = F.relu(t_405)
        t_407 = self.n_Conv_21(t_406)
        t_408 = F.relu(t_407)
        t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
        t_409 = self.n_Conv_22(t_408_padded)
        t_410 = F.relu(t_409)
        t_411 = self.n_Conv_23(t_410)
        t_412 = torch.add(t_411, t_406)
        t_413 = F.relu(t_412)
        t_414 = self.n_Conv_24(t_413)
        t_415 = F.relu(t_414)
        t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
        t_416 = self.n_Conv_25(t_415_padded)
        t_417 = F.relu(t_416)
        t_418 = self.n_Conv_26(t_417)
        t_419 = torch.add(t_418, t_413)
        t_420 = F.relu(t_419)
        t_421 = self.n_Conv_27(t_420)
        t_422 = F.relu(t_421)
        t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
        t_423 = self.n_Conv_28(t_422_padded)
        t_424 = F.relu(t_423)
        t_425 = self.n_Conv_29(t_424)
        t_426 = torch.add(t_425, t_420)
        t_427 = F.relu(t_426)
        t_428 = self.n_Conv_30(t_427)
        t_429 = F.relu(t_428)
        t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
        t_430 = self.n_Conv_31(t_429_padded)
        t_431 = F.relu(t_430)
        t_432 = self.n_Conv_32(t_431)
        t_433 = torch.add(t_432, t_427)
        t_434 = F.relu(t_433)
        t_435 = self.n_Conv_33(t_434)
        t_436 = F.relu(t_435)
        t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
        t_437 = self.n_Conv_34(t_436_padded)
        t_438 = F.relu(t_437)
        t_439 = self.n_Conv_35(t_438)
        t_440 = torch.add(t_439, t_434)
        t_441 = F.relu(t_440)
        t_442 = self.n_Conv_36(t_441)
        t_443 = self.n_Conv_37(t_441)
        t_444 = F.relu(t_443)
        t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
        t_445 = self.n_Conv_38(t_444_padded)
        t_446 = F.relu(t_445)
        t_447 = self.n_Conv_39(t_446)
        t_448 = torch.add(t_447, t_442)
        t_449 = F.relu(t_448)
        t_450 = self.n_Conv_40(t_449)
        t_451 = F.relu(t_450)
        t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
        t_452 = self.n_Conv_41(t_451_padded)
        t_453 = F.relu(t_452)
        t_454 = self.n_Conv_42(t_453)
        t_455 = torch.add(t_454, t_449)
        t_456 = F.relu(t_455)
        t_457 = self.n_Conv_43(t_456)
        t_458 = F.relu(t_457)
        t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
        t_459 = self.n_Conv_44(t_458_padded)
        t_460 = F.relu(t_459)
        t_461 = self.n_Conv_45(t_460)
        t_462 = torch.add(t_461, t_456)
        t_463 = F.relu(t_462)
        t_464 = self.n_Conv_46(t_463)
        t_465 = F.relu(t_464)
        t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
        t_466 = self.n_Conv_47(t_465_padded)
        t_467 = F.relu(t_466)
        t_468 = self.n_Conv_48(t_467)
        t_469 = torch.add(t_468, t_463)
        t_470 = F.relu(t_469)
        t_471 = self.n_Conv_49(t_470)
        t_472 = F.relu(t_471)
        t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
        t_473 = self.n_Conv_50(t_472_padded)
        t_474 = F.relu(t_473)
        t_475 = self.n_Conv_51(t_474)
        t_476 = torch.add(t_475, t_470)
        t_477 = F.relu(t_476)
        t_478 = self.n_Conv_52(t_477)
        t_479 = F.relu(t_478)
        t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
        t_480 = self.n_Conv_53(t_479_padded)
        t_481 = F.relu(t_480)
        t_482 = self.n_Conv_54(t_481)
        t_483 = torch.add(t_482, t_477)
        t_484 = F.relu(t_483)
        t_485 = self.n_Conv_55(t_484)
        t_486 = F.relu(t_485)
        t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
        t_487 = self.n_Conv_56(t_486_padded)
        t_488 = F.relu(t_487)
        t_489 = self.n_Conv_57(t_488)
        t_490 = torch.add(t_489, t_484)
        t_491 = F.relu(t_490)
        t_492 = self.n_Conv_58(t_491)
        t_493 = F.relu(t_492)
        t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
        t_494 = self.n_Conv_59(t_493_padded)
        t_495 = F.relu(t_494)
        t_496 = self.n_Conv_60(t_495)
        t_497 = torch.add(t_496, t_491)
        t_498 = F.relu(t_497)
        t_499 = self.n_Conv_61(t_498)
        t_500 = F.relu(t_499)
        t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
        t_501 = self.n_Conv_62(t_500_padded)
        t_502 = F.relu(t_501)
        t_503 = self.n_Conv_63(t_502)
        t_504 = torch.add(t_503, t_498)
        t_505 = F.relu(t_504)
        t_506 = self.n_Conv_64(t_505)
        t_507 = F.relu(t_506)
        t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
        t_508 = self.n_Conv_65(t_507_padded)
        t_509 = F.relu(t_508)
        t_510 = self.n_Conv_66(t_509)
        t_511 = torch.add(t_510, t_505)
        t_512 = F.relu(t_511)
        t_513 = self.n_Conv_67(t_512)
        t_514 = F.relu(t_513)
        t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
        t_515 = self.n_Conv_68(t_514_padded)
        t_516 = F.relu(t_515)
        t_517 = self.n_Conv_69(t_516)
        t_518 = torch.add(t_517, t_512)
        t_519 = F.relu(t_518)
        t_520 = self.n_Conv_70(t_519)
        t_521 = F.relu(t_520)
        t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
        t_522 = self.n_Conv_71(t_521_padded)
        t_523 = F.relu(t_522)
        t_524 = self.n_Conv_72(t_523)
        t_525 = torch.add(t_524, t_519)
        t_526 = F.relu(t_525)
        t_527 = self.n_Conv_73(t_526)
        t_528 = F.relu(t_527)
        t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
        t_529 = self.n_Conv_74(t_528_padded)
        t_530 = F.relu(t_529)
        t_531 = self.n_Conv_75(t_530)
        t_532 = torch.add(t_531, t_526)
        t_533 = F.relu(t_532)
        t_534 = self.n_Conv_76(t_533)
        t_535 = F.relu(t_534)
        t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
        t_536 = self.n_Conv_77(t_535_padded)
        t_537 = F.relu(t_536)
        t_538 = self.n_Conv_78(t_537)
        t_539 = torch.add(t_538, t_533)
        t_540 = F.relu(t_539)
        t_541 = self.n_Conv_79(t_540)
        t_542 = F.relu(t_541)
        t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
        t_543 = self.n_Conv_80(t_542_padded)
        t_544 = F.relu(t_543)
        t_545 = self.n_Conv_81(t_544)
        t_546 = torch.add(t_545, t_540)
        t_547 = F.relu(t_546)
        t_548 = self.n_Conv_82(t_547)
        t_549 = F.relu(t_548)
        t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
        t_550 = self.n_Conv_83(t_549_padded)
        t_551 = F.relu(t_550)
        t_552 = self.n_Conv_84(t_551)
        t_553 = torch.add(t_552, t_547)
        t_554 = F.relu(t_553)
        t_555 = self.n_Conv_85(t_554)
        t_556 = F.relu(t_555)
        t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
        t_557 = self.n_Conv_86(t_556_padded)
        t_558 = F.relu(t_557)
        t_559 = self.n_Conv_87(t_558)
        t_560 = torch.add(t_559, t_554)
        t_561 = F.relu(t_560)
        t_562 = self.n_Conv_88(t_561)
        t_563 = F.relu(t_562)
        t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
        t_564 = self.n_Conv_89(t_563_padded)
        t_565 = F.relu(t_564)
        t_566 = self.n_Conv_90(t_565)
        t_567 = torch.add(t_566, t_561)
        t_568 = F.relu(t_567)
        t_569 = self.n_Conv_91(t_568)
        t_570 = F.relu(t_569)
        t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
        t_571 = self.n_Conv_92(t_570_padded)
        t_572 = F.relu(t_571)
        t_573 = self.n_Conv_93(t_572)
        t_574 = torch.add(t_573, t_568)
        t_575 = F.relu(t_574)
        t_576 = self.n_Conv_94(t_575)
        t_577 = F.relu(t_576)
        t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
        t_578 = self.n_Conv_95(t_577_padded)
        t_579 = F.relu(t_578)
        t_580 = self.n_Conv_96(t_579)
        t_581 = torch.add(t_580, t_575)
        t_582 = F.relu(t_581)
        t_583 = self.n_Conv_97(t_582)
        t_584 = F.relu(t_583)
        t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
        t_585 = self.n_Conv_98(t_584_padded)
        t_586 = F.relu(t_585)
        t_587 = self.n_Conv_99(t_586)
        t_588 = self.n_Conv_100(t_582)
        t_589 = torch.add(t_587, t_588)
        t_590 = F.relu(t_589)
        t_591 = self.n_Conv_101(t_590)
        t_592 = F.relu(t_591)
        t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
        t_593 = self.n_Conv_102(t_592_padded)
        t_594 = F.relu(t_593)
        t_595 = self.n_Conv_103(t_594)
        t_596 = torch.add(t_595, t_590)
        t_597 = F.relu(t_596)
        t_598 = self.n_Conv_104(t_597)
        t_599 = F.relu(t_598)
        t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
        t_600 = self.n_Conv_105(t_599_padded)
        t_601 = F.relu(t_600)
        t_602 = self.n_Conv_106(t_601)
        t_603 = torch.add(t_602, t_597)
        t_604 = F.relu(t_603)
        t_605 = self.n_Conv_107(t_604)
        t_606 = F.relu(t_605)
        t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
        t_607 = self.n_Conv_108(t_606_padded)
        t_608 = F.relu(t_607)
        t_609 = self.n_Conv_109(t_608)
        t_610 = torch.add(t_609, t_604)
        t_611 = F.relu(t_610)
        t_612 = self.n_Conv_110(t_611)
        t_613 = F.relu(t_612)
        t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
        t_614 = self.n_Conv_111(t_613_padded)
        t_615 = F.relu(t_614)
        t_616 = self.n_Conv_112(t_615)
        t_617 = torch.add(t_616, t_611)
        t_618 = F.relu(t_617)
        t_619 = self.n_Conv_113(t_618)
        t_620 = F.relu(t_619)
        t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
        t_621 = self.n_Conv_114(t_620_padded)
        t_622 = F.relu(t_621)
        t_623 = self.n_Conv_115(t_622)
        t_624 = torch.add(t_623, t_618)
        t_625 = F.relu(t_624)
        t_626 = self.n_Conv_116(t_625)
        t_627 = F.relu(t_626)
        t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
        t_628 = self.n_Conv_117(t_627_padded)
        t_629 = F.relu(t_628)
        t_630 = self.n_Conv_118(t_629)
        t_631 = torch.add(t_630, t_625)
        t_632 = F.relu(t_631)
        t_633 = self.n_Conv_119(t_632)
        t_634 = F.relu(t_633)
        t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
        t_635 = self.n_Conv_120(t_634_padded)
        t_636 = F.relu(t_635)
        t_637 = self.n_Conv_121(t_636)
        t_638 = torch.add(t_637, t_632)
        t_639 = F.relu(t_638)
        t_640 = self.n_Conv_122(t_639)
        t_641 = F.relu(t_640)
        t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
        t_642 = self.n_Conv_123(t_641_padded)
        t_643 = F.relu(t_642)
        t_644 = self.n_Conv_124(t_643)
        t_645 = torch.add(t_644, t_639)
        t_646 = F.relu(t_645)
        t_647 = self.n_Conv_125(t_646)
        t_648 = F.relu(t_647)
        t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
        t_649 = self.n_Conv_126(t_648_padded)
        t_650 = F.relu(t_649)
        t_651 = self.n_Conv_127(t_650)
        t_652 = torch.add(t_651, t_646)
        t_653 = F.relu(t_652)
        t_654 = self.n_Conv_128(t_653)
        t_655 = F.relu(t_654)
        t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
        t_656 = self.n_Conv_129(t_655_padded)
        t_657 = F.relu(t_656)
        t_658 = self.n_Conv_130(t_657)
        t_659 = torch.add(t_658, t_653)
        t_660 = F.relu(t_659)
        t_661 = self.n_Conv_131(t_660)
        t_662 = F.relu(t_661)
        t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
        t_663 = self.n_Conv_132(t_662_padded)
        t_664 = F.relu(t_663)
        t_665 = self.n_Conv_133(t_664)
        t_666 = torch.add(t_665, t_660)
        t_667 = F.relu(t_666)
        t_668 = self.n_Conv_134(t_667)
        t_669 = F.relu(t_668)
        t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
        t_670 = self.n_Conv_135(t_669_padded)
        t_671 = F.relu(t_670)
        t_672 = self.n_Conv_136(t_671)
        t_673 = torch.add(t_672, t_667)
        t_674 = F.relu(t_673)
        t_675 = self.n_Conv_137(t_674)
        t_676 = F.relu(t_675)
        t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
        t_677 = self.n_Conv_138(t_676_padded)
        t_678 = F.relu(t_677)
        t_679 = self.n_Conv_139(t_678)
        t_680 = torch.add(t_679, t_674)
        t_681 = F.relu(t_680)
        t_682 = self.n_Conv_140(t_681)
        t_683 = F.relu(t_682)
        t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
        t_684 = self.n_Conv_141(t_683_padded)
        t_685 = F.relu(t_684)
        t_686 = self.n_Conv_142(t_685)
        t_687 = torch.add(t_686, t_681)
        t_688 = F.relu(t_687)
        t_689 = self.n_Conv_143(t_688)
        t_690 = F.relu(t_689)
        t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
        t_691 = self.n_Conv_144(t_690_padded)
        t_692 = F.relu(t_691)
        t_693 = self.n_Conv_145(t_692)
        t_694 = torch.add(t_693, t_688)
        t_695 = F.relu(t_694)
        t_696 = self.n_Conv_146(t_695)
        t_697 = F.relu(t_696)
        t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
        t_698 = self.n_Conv_147(t_697_padded)
        t_699 = F.relu(t_698)
        t_700 = self.n_Conv_148(t_699)
        t_701 = torch.add(t_700, t_695)
        t_702 = F.relu(t_701)
        t_703 = self.n_Conv_149(t_702)
        t_704 = F.relu(t_703)
        t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
        t_705 = self.n_Conv_150(t_704_padded)
        t_706 = F.relu(t_705)
        t_707 = self.n_Conv_151(t_706)
        t_708 = torch.add(t_707, t_702)
        t_709 = F.relu(t_708)
        t_710 = self.n_Conv_152(t_709)
        t_711 = F.relu(t_710)
        t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
        t_712 = self.n_Conv_153(t_711_padded)
        t_713 = F.relu(t_712)
        t_714 = self.n_Conv_154(t_713)
        t_715 = torch.add(t_714, t_709)
        t_716 = F.relu(t_715)
        t_717 = self.n_Conv_155(t_716)
        t_718 = F.relu(t_717)
        t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
        t_719 = self.n_Conv_156(t_718_padded)
        t_720 = F.relu(t_719)
        t_721 = self.n_Conv_157(t_720)
        t_722 = torch.add(t_721, t_716)
        t_723 = F.relu(t_722)
        t_724 = self.n_Conv_158(t_723)
        t_725 = self.n_Conv_159(t_723)
        t_726 = F.relu(t_725)
        t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
        t_727 = self.n_Conv_160(t_726_padded)
        t_728 = F.relu(t_727)
        t_729 = self.n_Conv_161(t_728)
        t_730 = torch.add(t_729, t_724)
        t_731 = F.relu(t_730)
        t_732 = self.n_Conv_162(t_731)
        t_733 = F.relu(t_732)
        t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
        t_734 = self.n_Conv_163(t_733_padded)
        t_735 = F.relu(t_734)
        t_736 = self.n_Conv_164(t_735)
        t_737 = torch.add(t_736, t_731)
        t_738 = F.relu(t_737)
        t_739 = self.n_Conv_165(t_738)
        t_740 = F.relu(t_739)
        t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
        t_741 = self.n_Conv_166(t_740_padded)
        t_742 = F.relu(t_741)
        t_743 = self.n_Conv_167(t_742)
        t_744 = torch.add(t_743, t_738)
        t_745 = F.relu(t_744)
        t_746 = self.n_Conv_168(t_745)
        t_747 = self.n_Conv_169(t_745)
        t_748 = F.relu(t_747)
        t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
        t_749 = self.n_Conv_170(t_748_padded)
        t_750 = F.relu(t_749)
        t_751 = self.n_Conv_171(t_750)
        t_752 = torch.add(t_751, t_746)
        t_753 = F.relu(t_752)
        t_754 = self.n_Conv_172(t_753)
        t_755 = F.relu(t_754)
        t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
        t_756 = self.n_Conv_173(t_755_padded)
        t_757 = F.relu(t_756)
        t_758 = self.n_Conv_174(t_757)
        t_759 = torch.add(t_758, t_753)
        t_760 = F.relu(t_759)
        t_761 = self.n_Conv_175(t_760)
        t_762 = F.relu(t_761)
        t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
        t_763 = self.n_Conv_176(t_762_padded)
        t_764 = F.relu(t_763)
        t_765 = self.n_Conv_177(t_764)
        t_766 = torch.add(t_765, t_760)
        t_767 = F.relu(t_766)
        t_768 = self.n_Conv_178(t_767)
        t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
        t_770 = torch.squeeze(t_769, 3)
        t_770 = torch.squeeze(t_770, 2)
        t_771 = torch.sigmoid(t_770)
        return t_771

    def load_state_dict(self, state_dict, **kwargs):
        self.tags = state_dict.get('tags', [])

        super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
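
# Illustrative only, not part of the diff: a minimal usage sketch for the model
# above. The checkpoint filename and the 512x512 input size are assumptions;
# the 'tags' key is handled by the load_state_dict override.
if __name__ == '__main__':
    model = DeepDanbooruModel()
    model.load_state_dict(torch.load('model-resnet_custom_v3.pt', map_location='cpu'))  # assumed path
    model.eval()
    with torch.no_grad():
        probs = model(torch.zeros(1, 512, 512, 3))  # forward expects an NHWC batch
    # one sigmoid score per entry in model.tags; pick the ten highest-scoring tags
    top10 = [model.tags[int(i)] for i in probs[0].topk(10).indices]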
@ -1,80 +1,463 @@
-# this file is taken from https://github.com/xinntao/ESRGAN
-
-import functools
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-def make_layer(block, n_layers):
-    layers = []
-    for _ in range(n_layers):
-        layers.append(block())
-    return nn.Sequential(*layers)
-
-
-class ResidualDenseBlock_5C(nn.Module):
-    def __init__(self, nf=64, gc=32, bias=True):
-        super(ResidualDenseBlock_5C, self).__init__()
-        # gc: growth channel, i.e. intermediate channels
-        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
-        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
-        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
-        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
-        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
-        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
-        # initialization
-        # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
-
-    def forward(self, x):
-        x1 = self.lrelu(self.conv1(x))
-        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
-        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
-        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
-        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
-        return x5 * 0.2 + x
-
-
-class RRDB(nn.Module):
-    '''Residual in Residual Dense Block'''
-
-    def __init__(self, nf, gc=32):
-        super(RRDB, self).__init__()
-        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
-        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
-        self.RDB3 = ResidualDenseBlock_5C(nf, gc)
-
-    def forward(self, x):
-        out = self.RDB1(x)
-        out = self.RDB2(out)
-        out = self.RDB3(out)
-        return out * 0.2 + x
-
-
-class RRDBNet(nn.Module):
-    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
-        super(RRDBNet, self).__init__()
-        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
-
-        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
-        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
-        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-        #### upsampling
-        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
-
-        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
-    def forward(self, x):
-        fea = self.conv_first(x)
-        trunk = self.trunk_conv(self.RRDB_trunk(fea))
-        fea = fea + trunk
-
-        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
-        fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
-        out = self.conv_last(self.lrelu(self.HRconv(fea)))
-
-        return out
+# this file is adapted from https://github.com/victorca25/iNNfer
+
+import math
+import functools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+####################
+# RRDBNet Generator
+####################
+
+class RRDBNet(nn.Module):
+    def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
+            act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
+            finalact=None, gaussian_noise=False, plus=False):
+        super(RRDBNet, self).__init__()
+        n_upscale = int(math.log(upscale, 2))
+        if upscale == 3:
+            n_upscale = 1
+
+        self.resrgan_scale = 0
+        if in_nc % 16 == 0:
+            self.resrgan_scale = 1
+        elif in_nc != 4 and in_nc % 4 == 0:
+            self.resrgan_scale = 2
+
+        fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+        rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+            norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
+            gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
+        LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
+
+        if upsample_mode == 'upconv':
+            upsample_block = upconv_block
+        elif upsample_mode == 'pixelshuffle':
+            upsample_block = pixelshuffle_block
+        else:
+            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
+        if upscale == 3:
+            upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
+        else:
+            upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
+        HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
+        HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+
+        outact = act(finalact) if finalact else None
+
+        self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
+            *upsampler, HR_conv0, HR_conv1, outact)
+
+    def forward(self, x, outm=None):
+        if self.resrgan_scale == 1:
+            feat = pixel_unshuffle(x, scale=4)
+        elif self.resrgan_scale == 2:
+            feat = pixel_unshuffle(x, scale=2)
+        else:
+            feat = x
+
+        return self.model(feat)
+
+
+class RRDB(nn.Module):
+    """
+    Residual in Residual Dense Block
+    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
+    """
+
+    def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+            norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
+            spectral_norm=False, gaussian_noise=False, plus=False):
+        super(RRDB, self).__init__()
+        # This is for backwards compatibility with existing models
+        if nr == 3:
+            self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+                norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+                gaussian_noise=gaussian_noise, plus=plus)
+            self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+                norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+                gaussian_noise=gaussian_noise, plus=plus)
+            self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+                norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+                gaussian_noise=gaussian_noise, plus=plus)
+        else:
+            RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+                norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+                gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
+            self.RDBs = nn.Sequential(*RDB_list)
+
+    def forward(self, x):
+        if hasattr(self, 'RDB1'):
+            out = self.RDB1(x)
+            out = self.RDB2(out)
+            out = self.RDB3(out)
+        else:
+            out = self.RDBs(x)
+        return out * 0.2 + x
+
+
+class ResidualDenseBlock_5C(nn.Module):
+    """
+    Residual Dense Block
+    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
+    Modified options that can be used:
+        - "Partial Convolution based Padding" arXiv:1811.11718
+        - "Spectral normalization" arXiv:1802.05957
+        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
+            {Rakotonirina} and A. {Rasoanaivo}
+    """
+
+    def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+            norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
+            spectral_norm=False, gaussian_noise=False, plus=False):
+        super(ResidualDenseBlock_5C, self).__init__()
+
+        self.noise = GaussianNoise() if gaussian_noise else None
+        self.conv1x1 = conv1x1(nf, gc) if plus else None
+
+        self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+            norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+            spectral_norm=spectral_norm)
+        self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+            norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+            spectral_norm=spectral_norm)
+        self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+            norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+            spectral_norm=spectral_norm)
+        self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+            norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+            spectral_norm=spectral_norm)
+        if mode == 'CNA':
+            last_act = None
+        else:
+            last_act = act_type
+        self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
+            norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
+            spectral_norm=spectral_norm)
+
+    def forward(self, x):
+        x1 = self.conv1(x)
+        x2 = self.conv2(torch.cat((x, x1), 1))
+        if self.conv1x1:
+            x2 = x2 + self.conv1x1(x)
+        x3 = self.conv3(torch.cat((x, x1, x2), 1))
+        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+        if self.conv1x1:
+            x4 = x4 + x2
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        if self.noise:
+            return self.noise(x5.mul(0.2) + x)
+        else:
+            return x5 * 0.2 + x
+
+
+####################
+# ESRGANplus
+####################
+
+class GaussianNoise(nn.Module):
+    def __init__(self, sigma=0.1, is_relative_detach=False):
+        super().__init__()
+        self.sigma = sigma
+        self.is_relative_detach = is_relative_detach
+        self.noise = torch.tensor(0, dtype=torch.float)
+
+    def forward(self, x):
+        if self.training and self.sigma != 0:
+            self.noise = self.noise.to(x.device)
+            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
+            sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
+            x = x + sampled_noise
+        return x
+
+def conv1x1(in_planes, out_planes, stride=1):
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+####################
+# SRVGGNetCompact
+####################
+
+class SRVGGNetCompact(nn.Module):
+    """A compact VGG-style network structure for super-resolution.
+    This class is copied from https://github.com/xinntao/Real-ESRGAN
+    """
+
+    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
+        super(SRVGGNetCompact, self).__init__()
+        self.num_in_ch = num_in_ch
+        self.num_out_ch = num_out_ch
+        self.num_feat = num_feat
+        self.num_conv = num_conv
+        self.upscale = upscale
+        self.act_type = act_type
+
+        self.body = nn.ModuleList()
+        # the first conv
+        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+        # the first activation
+        if act_type == 'relu':
+            activation = nn.ReLU(inplace=True)
+        elif act_type == 'prelu':
+            activation = nn.PReLU(num_parameters=num_feat)
+        elif act_type == 'leakyrelu':
+            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.body.append(activation)
+
+        # the body structure
+        for _ in range(num_conv):
+            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+            # activation
+            if act_type == 'relu':
+                activation = nn.ReLU(inplace=True)
+            elif act_type == 'prelu':
+                activation = nn.PReLU(num_parameters=num_feat)
+            elif act_type == 'leakyrelu':
+                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+            self.body.append(activation)
+
+        # the last conv
+        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+        # upsample
+        self.upsampler = nn.PixelShuffle(upscale)
+
+    def forward(self, x):
+        out = x
+        for i in range(0, len(self.body)):
+            out = self.body[i](out)
+
+        out = self.upsampler(out)
+        # add the nearest upsampled image, so that the network learns the residual
+        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
+        out += base
+        return out
+
+
+####################
+# Upsampler
+####################
+
+class Upsample(nn.Module):
+    r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
+    The input data is assumed to be of the form
+    `minibatch x channels x [optional depth] x [optional height] x width`.
+    """
+
+    def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
+        super(Upsample, self).__init__()
+        if isinstance(scale_factor, tuple):
+            self.scale_factor = tuple(float(factor) for factor in scale_factor)
+        else:
+            self.scale_factor = float(scale_factor) if scale_factor else None
+        self.mode = mode
+        self.size = size
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
+
+    def extra_repr(self):
+        if self.scale_factor is not None:
+            info = 'scale_factor=' + str(self.scale_factor)
+        else:
+            info = 'size=' + str(self.size)
+        info += ', mode=' + self.mode
+        return info
+
+
+def pixel_unshuffle(x, scale):
+    """ Pixel unshuffle.
+    Args:
+        x (Tensor): Input feature with shape (b, c, hh, hw).
+        scale (int): Downsample ratio.
+    Returns:
+        Tensor: the pixel unshuffled feature.
+    """
+    b, c, hh, hw = x.size()
+    out_channel = c * (scale**2)
+    assert hh % scale == 0 and hw % scale == 0
+    h = hh // scale
+    w = hw // scale
+    x_view = x.view(b, c, h, scale, w, scale)
+    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+
+def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+                       pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
+    """
+    Pixel shuffle layer
+    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
+    Neural Network, CVPR17)
+    """
+    conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
+                      pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
+    pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+    n = norm(norm_type, out_nc) if norm_type else None
+    a = act(act_type) if act_type else None
+    return sequential(conv, pixel_shuffle, n, a)
+
+
+def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+                 pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
+    """ Upconv layer """
+    upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
+    upsample = Upsample(scale_factor=upscale_factor, mode=mode)
+    conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
+                      pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
+    return sequential(upsample, conv)
+
+
+####################
+# Basic blocks
+####################
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+    """Make layers by stacking the same blocks.
+    Args:
+        basic_block (nn.module): nn.module class for basic block. (block)
+        num_basic_block (int): number of blocks. (n_layers)
+    Returns:
+        nn.Sequential: Stacked blocks in nn.Sequential.
+    """
+    layers = []
+    for _ in range(num_basic_block):
+        layers.append(basic_block(**kwarg))
+    return nn.Sequential(*layers)
+
+
+def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
+    """ activation helper """
+    act_type = act_type.lower()
+    if act_type == 'relu':
+        layer = nn.ReLU(inplace)
+    elif act_type in ('leakyrelu', 'lrelu'):
+        layer = nn.LeakyReLU(neg_slope, inplace)
+    elif act_type == 'prelu':
+        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+    elif act_type == 'tanh':  # [-1, 1] range output
+        layer = nn.Tanh()
+    elif act_type == 'sigmoid':  # [0, 1] range output
+        layer = nn.Sigmoid()
+    else:
+        raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
+    return layer
+
+
+class Identity(nn.Module):
+    def __init__(self, *kwargs):
+        super(Identity, self).__init__()
+
+    def forward(self, x, *kwargs):
+        return x
+
+
+def norm(norm_type, nc):
+    """ Return a normalization layer """
+    norm_type = norm_type.lower()
+    if norm_type == 'batch':
+        layer = nn.BatchNorm2d(nc, affine=True)
+    elif norm_type == 'instance':
+        layer = nn.InstanceNorm2d(nc, affine=False)
+    elif norm_type == 'none':
+        def norm_layer(x): return Identity()
+    else:
+        raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
+    return layer
+
+
+def pad(pad_type, padding):
+    """ padding layer helper """
+    pad_type = pad_type.lower()
+    if padding == 0:
+        return None
+    if pad_type == 'reflect':
+        layer = nn.ReflectionPad2d(padding)
+    elif pad_type == 'replicate':
+        layer = nn.ReplicationPad2d(padding)
+    elif pad_type == 'zero':
+        layer = nn.ZeroPad2d(padding)
+    else:
+        raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
+    return layer
+
+
+def get_valid_padding(kernel_size, dilation):
+    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+    padding = (kernel_size - 1) // 2
+    return padding
+
+
+class ShortcutBlock(nn.Module):
+    """ Elementwise sum the output of a submodule to its input """
+    def __init__(self, submodule):
+        super(ShortcutBlock, self).__init__()
+        self.sub = submodule
+
+    def forward(self, x):
+        output = x + self.sub(x)
+        return output
+
+    def __repr__(self):
+        return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
+
+
+def sequential(*args):
+    """ Flatten Sequential. It unwraps nn.Sequential. """
+    if len(args) == 1:
+        if isinstance(args[0], OrderedDict):
+            raise NotImplementedError('sequential does not support OrderedDict input.')
+        return args[0]  # No sequential is needed.
+    modules = []
+    for module in args:
+        if isinstance(module, nn.Sequential):
+            for submodule in module.children():
+                modules.append(submodule)
+        elif isinstance(module, nn.Module):
+            modules.append(module)
+    return nn.Sequential(*modules)
+
+
+def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
+               pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
+               spectral_norm=False):
+    """ Conv layer with padding, normalization, activation """
+    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
+    padding = get_valid_padding(kernel_size, dilation)
+    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
+    padding = padding if pad_type == 'zero' else 0
+
+    if convtype=='PartialConv2D':
+        c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+                          dilation=dilation, bias=bias, groups=groups)
+    elif convtype=='DeformConv2D':
+        c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+                         dilation=dilation, bias=bias, groups=groups)
+    elif convtype=='Conv3D':
+        c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+                      dilation=dilation, bias=bias, groups=groups)
+    else:
+        c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+                      dilation=dilation, bias=bias, groups=groups)
+
+    if spectral_norm:
+        c = nn.utils.spectral_norm(c)
+
+    a = act(act_type) if act_type else None
+    if 'CNA' in mode:
+        n = norm(norm_type, out_nc) if norm_type else None
+        return sequential(p, c, n, a)
+    elif mode == 'NAC':
+        if norm_type is None and act_type is not None:
+            a = act(act_type, inplace=False)
+        n = norm(norm_type, in_nc) if norm_type else None
+        return sequential(n, a, p, c)
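
# Illustrative only, not part of the diff: a sketch instantiating the new
# architecture for a standard 4x ESRGAN generator (3-channel RGB, 64 features,
# 23 RRDB blocks); the parameter values are assumptions, not mandated anywhere
# in this file.
if __name__ == '__main__':
    net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
    lr = torch.zeros(1, 3, 64, 64)  # dummy low-resolution batch
    sr = net(lr)                    # upscale=4 by default -> shape (1, 3, 256, 256)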
@ -0,0 +1,99 @@
import os
import sys
import traceback

import git

from modules import paths, shared

extensions = []
extensions_dir = os.path.join(paths.script_path, "extensions")
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")


def active():
    return [x for x in extensions if x.enabled]


class Extension:
    def __init__(self, name, path, enabled=True, is_builtin=False):
        self.name = name
        self.path = path
        self.enabled = enabled
        self.status = ''
        self.can_update = False
        self.is_builtin = is_builtin

        repo = None
        try:
            if os.path.exists(os.path.join(path, ".git")):
                repo = git.Repo(path)
        except Exception:
            print(f"Error reading github repository info from {path}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

        if repo is None or repo.bare:
            self.remote = None
        else:
            try:
                self.remote = next(repo.remote().urls, None)
                self.status = 'unknown'
            except Exception:
                self.remote = None

    def list_files(self, subdir, extension):
        from modules import scripts

        dirpath = os.path.join(self.path, subdir)
        if not os.path.isdir(dirpath):
            return []

        res = []
        for filename in sorted(os.listdir(dirpath)):
            res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))

        res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]

        return res

    def check_updates(self):
        repo = git.Repo(self.path)
        for fetch in repo.remote().fetch("--dry-run"):
            if fetch.flags != fetch.HEAD_UPTODATE:
                self.can_update = True
                self.status = "behind"
                return

        self.can_update = False
        self.status = "latest"

    def fetch_and_reset_hard(self):
        repo = git.Repo(self.path)
        # Fix: `error: Your local changes to the following files would be overwritten by merge`,
        # which happens because WSL2 Docker sets 755 file permissions instead of 644.
        repo.git.fetch('--all')
        repo.git.reset('--hard', 'origin')


def list_extensions():
    extensions.clear()

    if not os.path.isdir(extensions_dir):
        return

    paths = []
    for dirname in [extensions_dir, extensions_builtin_dir]:
        if not os.path.isdir(dirname):
            return

        for extension_dirname in sorted(os.listdir(dirname)):
            path = os.path.join(dirname, extension_dirname)
            if not os.path.isdir(path):
                continue

            paths.append((extension_dirname, path, dirname == extensions_builtin_dir))

    for dirname, path, is_builtin in paths:
        extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
        extensions.append(extension)
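
# Illustrative only, not part of the diff: the expected call sequence, assuming
# shared.opts has already been populated (as the webui does during startup).
if __name__ == '__main__':
    list_extensions()                                    # scan both extension directories
    for ext in active():                                 # skips extensions disabled in settings
        script_files = ext.list_files("scripts", ".py")  # ScriptFile entries for each script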
@ -1,183 +0,0 @@
import os
import shutil
import sys


def traverse_all_files(output_dir, image_list, curr_dir=None):
    curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
    try:
        f_list = os.listdir(curr_path)
    except Exception:
        # curr_path is not a listable directory; treat it as a file if it has an extension and is not a .txt
        if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
            image_list.append(curr_dir)
        return image_list
    for file in f_list:
        file = file if curr_dir is None else os.path.join(curr_dir, file)
        file_path = os.path.join(curr_path, file)
        if file[-4:] == ".txt":
            pass
        elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
            image_list.append(file)
        else:
            image_list = traverse_all_files(output_dir, image_list, file)
    return image_list


def get_recent_images(dir_name, page_index, step, image_index, tabname):
    page_index = int(page_index)
    image_list = []
    if not os.path.exists(dir_name):
        pass
    elif os.path.isdir(dir_name):
        image_list = traverse_all_files(dir_name, image_list)
        image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
    else:
        print(f'ERROR: "{dir_name}" is not a directory. Check the path in the settings.', file=sys.stderr)
    num = 48 if tabname != "extras" else 12
    max_page_index = len(image_list) // num + 1
    page_index = max_page_index if page_index == -1 else page_index + step
    page_index = 1 if page_index < 1 else page_index
    page_index = max_page_index if page_index > max_page_index else page_index
    idx_frm = (page_index - 1) * num
    image_list = image_list[idx_frm:idx_frm + num]
    image_index = int(image_index)
    if image_index < 0 or image_index > len(image_list) - 1:
        current_file = None
        hidden = None
    else:
        current_file = image_list[int(image_index)]
        hidden = os.path.join(dir_name, current_file)
    return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""


def first_page_click(dir_name, page_index, image_index, tabname):
    return get_recent_images(dir_name, 1, 0, image_index, tabname)


def end_page_click(dir_name, page_index, image_index, tabname):
    return get_recent_images(dir_name, -1, 0, image_index, tabname)


def prev_page_click(dir_name, page_index, image_index, tabname):
    return get_recent_images(dir_name, page_index, -1, image_index, tabname)


def next_page_click(dir_name, page_index, image_index, tabname):
    return get_recent_images(dir_name, page_index, 1, image_index, tabname)


def page_index_change(dir_name, page_index, image_index, tabname):
    return get_recent_images(dir_name, page_index, 0, image_index, tabname)


def show_image_info(num, image_path, filenames):
    # print(f"select image {num}")
    file = filenames[int(num)]
    return file, num, os.path.join(image_path, file)


def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
    if name == "":
        return filenames, delete_num
    else:
        delete_num = int(delete_num)
        index = list(filenames).index(name)
        i = 0
        new_file_list = []
        for name in filenames:
            if i >= index and i < index + delete_num:
                path = os.path.join(dir_name, name)
                if os.path.exists(path):
                    print(f"Deleting file {path}")
                    os.remove(path)
                    txt_file = os.path.splitext(path)[0] + ".txt"
                    if os.path.exists(txt_file):
                        os.remove(txt_file)
                else:
                    print(f"File {path} does not exist")
            else:
                new_file_list.append(name)
            i += 1
        return new_file_list, 1


def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
    if opts.outdir_samples != "":
        dir_name = opts.outdir_samples
    elif tabname == "txt2img":
        dir_name = opts.outdir_txt2img_samples
    elif tabname == "img2img":
        dir_name = opts.outdir_img2img_samples
    elif tabname == "extras":
        dir_name = opts.outdir_extras_samples
    else:
        return
    with gr.Row():
        renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
        first_page = gr.Button('First Page')
        prev_page = gr.Button('Prev Page')
        page_index = gr.Number(value=1, label="Page Index")
        next_page = gr.Button('Next Page')
        end_page = gr.Button('End Page')
    with gr.Row(elem_id=tabname + "_images_history"):
        with gr.Row():
            with gr.Column(scale=2):
                history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
                with gr.Row():
                    delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
                    delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
            with gr.Column():
                with gr.Row():
                    pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
                    pnginfo_send_to_img2img = gr.Button('Send to img2img')
                with gr.Row():
                    with gr.Column():
                        img_file_info = gr.Textbox(label="Generate Info", interactive=False)
                        img_file_name = gr.Textbox(label="File Name", interactive=False)
                with gr.Row():
                    # hidden items
                    img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
                    tabname_box = gr.Textbox(tabname, visible=False)
                    image_index = gr.Textbox(value=-1, visible=False)
                    set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
                    filenames = gr.State()
                    hidden = gr.Image(type="pil", visible=False)
                    info1 = gr.Textbox(visible=False)
                    info2 = gr.Textbox(visible=False)

    # turn pages
    gallery_inputs = [img_path, page_index, image_index, tabname_box]
    gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]

    first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
    next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
    prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
    end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
    page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
    renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
    # page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])

    # other functions
    set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
    img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
    delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
    hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])

    # pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
    switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
    switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')


def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
    with gr.Blocks(analytics_enabled=False) as images_history:
        with gr.Tabs() as tabs:
            with gr.Tab("txt2img history"):
                with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
                    show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
            with gr.Tab("img2img history"):
                with gr.Blocks(analytics_enabled=False) as images_history_img2img:
                    show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
            with gr.Tab("extras history"):
                with gr.Blocks(analytics_enabled=False) as images_history_extras:
                    show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
    return images_history
@ -0,0 +1,5 @@
import sys

# this breaks any attempt to import xformers, which prevents the stable diffusion repo from trying to use it
if "--xformers" not in "".join(sys.argv):
    sys.modules["xformers"] = None
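The trick relies on documented CPython behavior: a `None` entry in `sys.modules` makes any later import of that name raise `ImportError` immediately. A minimal standalone demonstration (illustrative, not webui code):

```python
import sys

sys.modules["xformers"] = None  # same poisoning as above

try:
    import xformers  # raises ImportError because the sys.modules entry is None
except ImportError:
    print("xformers import blocked")
```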
@ -1,42 +0,0 @@
import torch
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from PIL import Image

import modules.shared as shared

safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = None
safety_checker = None

def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]

    return pil_images

# check and replace nsfw content
def check_safety(x_image):
    global safety_feature_extractor, safety_checker

    if safety_feature_extractor is None:
        safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)

    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)

    return x_checked_image, has_nsfw_concept


def censor_batch(x):
    x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
    x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
    x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)

    return x
@ -0,0 +1,315 @@
import sys
import traceback
from collections import namedtuple
import inspect
from typing import Optional, Dict, Any

from fastapi import FastAPI
from gradio import Blocks


def report_exception(c, job):
    print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
    print(traceback.format_exc(), file=sys.stderr)


class ImageSaveParams:
    def __init__(self, image, p, filename, pnginfo):
        self.image = image
        """the PIL image itself"""

        self.p = p
        """p object with processing parameters; either StableDiffusionProcessing or an object with the same fields"""

        self.filename = filename
        """name of the file that the image would be saved to"""

        self.pnginfo = pnginfo
        """dictionary with parameters for the image's PNG info data; infotext will have the key 'parameters'"""


class CFGDenoiserParams:
    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
        self.x = x
        """Latent image representation in the process of being denoised"""

        self.image_cond = image_cond
        """Conditioning image"""

        self.sigma = sigma
        """Current sigma noise step value"""

        self.sampling_step = sampling_step
        """Current sampling step number"""

        self.total_sampling_steps = total_sampling_steps
        """Total number of sampling steps planned"""


class UiTrainTabParams:
    def __init__(self, txt2img_preview_params):
        self.txt2img_preview_params = txt2img_preview_params


class ImageGridLoopParams:
    def __init__(self, imgs, cols, rows):
        self.imgs = imgs
        self.cols = cols
        self.rows = rows


ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict(
    callbacks_app_started=[],
    callbacks_model_loaded=[],
    callbacks_ui_tabs=[],
    callbacks_ui_train_tabs=[],
    callbacks_ui_settings=[],
    callbacks_before_image_saved=[],
    callbacks_image_saved=[],
    callbacks_cfg_denoiser=[],
    callbacks_before_component=[],
    callbacks_after_component=[],
    callbacks_image_grid=[],
    callbacks_infotext_pasted=[],
    callbacks_script_unloaded=[],
)


def clear_callbacks():
    for callback_list in callback_map.values():
        callback_list.clear()


def app_started_callback(demo: Optional[Blocks], app: FastAPI):
    for c in callback_map['callbacks_app_started']:
        try:
            c.callback(demo, app)
        except Exception:
            report_exception(c, 'app_started_callback')


def model_loaded_callback(sd_model):
    for c in callback_map['callbacks_model_loaded']:
        try:
            c.callback(sd_model)
        except Exception:
            report_exception(c, 'model_loaded_callback')


def ui_tabs_callback():
    res = []

    for c in callback_map['callbacks_ui_tabs']:
        try:
            res += c.callback() or []
        except Exception:
            report_exception(c, 'ui_tabs_callback')

    return res


def ui_train_tabs_callback(params: UiTrainTabParams):
    for c in callback_map['callbacks_ui_train_tabs']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'callbacks_ui_train_tabs')


def ui_settings_callback():
    for c in callback_map['callbacks_ui_settings']:
        try:
            c.callback()
        except Exception:
            report_exception(c, 'ui_settings_callback')


def before_image_saved_callback(params: ImageSaveParams):
    for c in callback_map['callbacks_before_image_saved']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'before_image_saved_callback')


def image_saved_callback(params: ImageSaveParams):
    for c in callback_map['callbacks_image_saved']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'image_saved_callback')


def cfg_denoiser_callback(params: CFGDenoiserParams):
    for c in callback_map['callbacks_cfg_denoiser']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'cfg_denoiser_callback')


def before_component_callback(component, **kwargs):
    for c in callback_map['callbacks_before_component']:
        try:
            c.callback(component, **kwargs)
        except Exception:
            report_exception(c, 'before_component_callback')


def after_component_callback(component, **kwargs):
    for c in callback_map['callbacks_after_component']:
        try:
            c.callback(component, **kwargs)
        except Exception:
            report_exception(c, 'after_component_callback')


def image_grid_callback(params: ImageGridLoopParams):
    for c in callback_map['callbacks_image_grid']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'image_grid')


def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
    for c in callback_map['callbacks_infotext_pasted']:
        try:
            c.callback(infotext, params)
        except Exception:
            report_exception(c, 'infotext_pasted')


def script_unloaded_callback():
    for c in reversed(callback_map['callbacks_script_unloaded']):
        try:
            c.callback()
        except Exception:
            report_exception(c, 'script_unloaded')


def add_callback(callbacks, fun):
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if len(stack) > 0 else 'unknown file'

    callbacks.append(ScriptCallback(filename, fun))


def remove_current_script_callbacks():
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if len(stack) > 0 else 'unknown file'
    if filename == 'unknown file':
        return
    for callback_list in callback_map.values():
        for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
            callback_list.remove(callback_to_remove)


def remove_callbacks_for_function(callback_func):
    for callback_list in callback_map.values():
        for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
            callback_list.remove(callback_to_remove)


def on_app_started(callback):
    """register a function to be called when the webui starts; the gradio `Blocks` component and
    fastapi `FastAPI` object are passed as the arguments"""
    add_callback(callback_map['callbacks_app_started'], callback)


def on_model_loaded(callback):
    """register a function to be called when the stable diffusion model is created; the model is
    passed as an argument; this function is also called when the script is reloaded."""
    add_callback(callback_map['callbacks_model_loaded'], callback)


def on_ui_tabs(callback):
    """register a function to be called when the UI is creating new tabs.
    The function must either return None, which means no new tabs are to be added, or a list, where
    each element is a tuple:
        (gradio_component, title, elem_id)

    gradio_component is a gradio component to be used for the contents of the tab (usually gr.Blocks)
    title is the tab text displayed to the user in the UI
    elem_id is the HTML id for the tab
    """
    add_callback(callback_map['callbacks_ui_tabs'], callback)


def on_ui_train_tabs(callback):
    """register a function to be called when the UI is creating new tabs for the train tab.
    Create your new tabs with gr.Tab.
    """
    add_callback(callback_map['callbacks_ui_train_tabs'], callback)


def on_ui_settings(callback):
    """register a function to be called before UI settings are populated; add your settings
    by using shared.opts.add_option(shared.OptionInfo(...))"""
    add_callback(callback_map['callbacks_ui_settings'], callback)


def on_before_image_saved(callback):
    """register a function to be called before an image is saved to a file.
    The callback is called with one argument:
        - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
    """
    add_callback(callback_map['callbacks_before_image_saved'], callback)


def on_image_saved(callback):
    """register a function to be called after an image is saved to a file.
    The callback is called with one argument:
        - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
    """
    add_callback(callback_map['callbacks_image_saved'], callback)


def on_cfg_denoiser(callback):
    """register a function to be called in the kdiffusion cfg_denoiser method after building the inner model inputs.
    The callback is called with one argument:
        - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
    """
    add_callback(callback_map['callbacks_cfg_denoiser'], callback)


def on_before_component(callback):
    """register a function to be called before a component is created.
    The callback is called with arguments:
        - component - gradio component that is about to be created.
        - **kwargs - args to gradio.components.IOComponent.__init__ function

    Use elem_id/label fields of kwargs to figure out which component it is.
    This can be useful to inject your own components somewhere in the middle of the vanilla UI.
    """
    add_callback(callback_map['callbacks_before_component'], callback)


def on_after_component(callback):
    """register a function to be called after a component is created. See on_before_component for more."""
    add_callback(callback_map['callbacks_after_component'], callback)


def on_image_grid(callback):
    """register a function to be called before making an image grid.
    The callback is called with one argument:
        - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
    """
    add_callback(callback_map['callbacks_image_grid'], callback)


def on_infotext_pasted(callback):
    """register a function to be called before applying an infotext.
    The callback is called with two arguments:
        - infotext: str - raw infotext.
        - result: Dict[str, Any] - parsed infotext parameters.
    """
    add_callback(callback_map['callbacks_infotext_pasted'], callback)


def on_script_unloaded(callback):
    """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
    the script did should be reverted here"""

    add_callback(callback_map['callbacks_script_unloaded'], callback)
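To illustrate the registration API, here is a sketch of what an extension script might do with it; the option key and handler names are hypothetical:

```python
from modules import script_callbacks, shared

def handle_image_saved(params):
    # params is an ImageSaveParams; fields are informational at this point
    print(f"image written to {params.filename}")

def add_settings():
    # hypothetical option key for a hypothetical extension
    shared.opts.add_option("my_extension_enabled", shared.OptionInfo(True, "Enable my extension"))

script_callbacks.on_image_saved(handle_image_saved)
script_callbacks.on_ui_settings(add_settings)
```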
@ -0,0 +1,34 @@
import os
import sys
import traceback
from types import ModuleType


def load_module(path):
    with open(path, "r", encoding="utf8") as file:
        text = file.read()

    compiled = compile(text, path, 'exec')
    module = ModuleType(os.path.basename(path))
    exec(compiled, module.__dict__)

    return module


def preload_extensions(extensions_dir, parser):
    if not os.path.isdir(extensions_dir):
        return

    for dirname in sorted(os.listdir(extensions_dir)):
        preload_script = os.path.join(extensions_dir, dirname, "preload.py")
        if not os.path.isfile(preload_script):
            continue

        try:
            module = load_module(preload_script)
            if hasattr(module, 'preload'):
                module.preload(parser)

        except Exception:
            print(f"Error running preload() for {preload_script}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
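For reference, a sketch of a preload script this loader would pick up — a hypothetical `extensions/<name>/preload.py`; `parser` is the `argparse.ArgumentParser` that `preload_extensions()` passes in:

```python
def preload(parser):
    # hypothetical flag; preload runs before the main argument parsing,
    # so the option is available to the rest of the program at startup
    parser.add_argument("--my-extension-model-dir", type=str, default=None,
                        help="path to models used by this extension")
```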
@ -0,0 +1,88 @@
import ldm.modules.encoders.modules
import open_clip
import torch
import transformers.utils.hub


class DisableInitialization:
    """
    When an object of this class enters a `with` block, it:
    - prevents torch's layer initialization functions from working
    - changes CLIP and OpenCLIP to not download model weights
    - changes CLIP to not make requests to check if there is a new version of a file you already have

    When it leaves the block, it reverts everything to how it was before.

    Use it like this:
    ```
    with DisableInitialization():
        do_things()
    ```
    """

    def __init__(self):
        self.replaced = []

    def replace(self, obj, field, func):
        original = getattr(obj, field, None)
        if original is None:
            return None

        self.replaced.append((obj, field, original))
        setattr(obj, field, func)

        return original

    def __enter__(self):
        def do_nothing(*args, **kwargs):
            pass

        def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
            return self.create_model_and_transforms(*args, pretrained=None, **kwargs)

        def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
            return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)

        def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
            args = args[0:3] + ('/', ) + args[4:]  # resolved_archive_file; must be set to something to prevent what seems to be a bug
            return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)

        def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):

            # this file is always 404, prevent making a request
            if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
                return None

            try:
                res = original(url, *args, local_files_only=True, **kwargs)
                if res is None:
                    res = original(url, *args, local_files_only=False, **kwargs)
                return res
            except Exception:
                return original(url, *args, local_files_only=False, **kwargs)

        def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)

        def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)

        def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)

        self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
        self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
        self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
        self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
        self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
        self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
        self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)

    def __exit__(self, exc_type, exc_val, exc_tb):
        for obj, field, original in self.replaced:
            setattr(obj, field, original)

        self.replaced.clear()
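A sketch of how a model loader might use this context manager — hedged: it assumes `instantiate_from_config` from `ldm.util` and an `sd_config` loaded elsewhere (e.g. via OmegaConf), and the exact call site may differ:

```python
from ldm.util import instantiate_from_config
from modules import sd_disable_initialization

# build the model skeleton without random init or network round-trips;
# the real weights are loaded from the checkpoint afterwards
with sd_disable_initialization.DisableInitialization():
    sd_model = instantiate_from_config(sd_config.model)  # sd_config assumed loaded elsewhere
```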
@ -0,0 +1,10 @@
from torch.utils.checkpoint import checkpoint

def BasicTransformerBlock_forward(self, x, context=None):
    return checkpoint(self._forward, x, context)

def AttentionBlock_forward(self, x):
    return checkpoint(self._forward, x)

def ResBlock_forward(self, x, emb):
    return checkpoint(self._forward, x, emb)
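These replacement forwards trade compute for memory via `torch.utils.checkpoint`. One plausible way they get wired up (a sketch; the actual assignments live in the hijack code elsewhere in the repo):

```python
import ldm.modules.attention
import ldm.modules.diffusionmodules.openaimodel

# route the original ldm modules through the checkpointed forwards defined above
ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
```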
@ -0,0 +1,308 @@
import math
from collections import namedtuple

import torch

from modules import prompt_parser, devices, sd_hijack
from modules.shared import opts


class PromptChunk:
    """
    This object contains token ids, weights (multipliers, e.g. 1.4) and textual inversion embedding info for a chunk of prompt.
    If a prompt is short, it is represented by one PromptChunk; otherwise, multiple are necessary.
    Each PromptChunk contains an exact amount of tokens - 77, which includes one each for the start and end token,
    so just 75 tokens from the prompt.
    """

    def __init__(self):
        self.tokens = []
        self.multipliers = []
        self.fixes = []


PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
"""An object of this type is a marker showing that a textual inversion embedding's vectors have to be placed at offset in the prompt
chunk. Those objects are found in PromptChunk.fixes, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""


class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
    """A pytorch module that is a wrapper for the FrozenCLIPEmbedder module. It enhances FrozenCLIPEmbedder, making it possible to
    have unlimited prompt length and assign weights to tokens in the prompt.
    """

    def __init__(self, wrapped, hijack):
        super().__init__()

        self.wrapped = wrapped
        """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
        depending on model."""

        self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
        self.chunk_length = 75

    def empty_chunk(self):
        """creates an empty PromptChunk and returns it"""

        chunk = PromptChunk()
        chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
        chunk.multipliers = [1.0] * (self.chunk_length + 2)
        return chunk

    def get_target_prompt_token_count(self, token_count):
        """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""

        return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length

    def tokenize(self, texts):
        """Converts a batch of texts into a batch of token ids"""

        raise NotImplementedError

    def encode_with_transformers(self, tokens):
        """
        converts a batch of token ids (in python lists) into a single tensor with a numeric representation of those tokens;
        All python lists with tokens are assumed to have the same length, usually 77.
        if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
        the model - it can be 768 or 1024.
        Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
        """

        raise NotImplementedError

    def encode_embedding_init_text(self, init_text, nvpt):
        """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
        transformers. nvpt is used as a maximum length in tokens. If the text produces fewer tokens than nvpt, only that many are returned."""

        raise NotImplementedError

    def tokenize_line(self, line):
        """
        this transforms a single prompt into a list of PromptChunk objects - as many as needed to
        represent the prompt.
        Returns the list and the total number of tokens in the prompt.
        """

        if opts.enable_emphasis:
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        chunks = []
        chunk = PromptChunk()
        token_count = 0
        last_comma = -1

        def next_chunk():
            """puts the current chunk into the list of results and produces the next one - empty"""
            nonlocal token_count
            nonlocal last_comma
            nonlocal chunk

            token_count += len(chunk.tokens)
            to_add = self.chunk_length - len(chunk.tokens)
            if to_add > 0:
                chunk.tokens += [self.id_end] * to_add
                chunk.multipliers += [1.0] * to_add

            chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
            chunk.multipliers = [1.0] + chunk.multipliers + [1.0]

            last_comma = -1
            chunks.append(chunk)
            chunk = PromptChunk()

        for tokens, (text, weight) in zip(tokenized, parsed):
            position = 0
            while position < len(tokens):
                token = tokens[position]

                if token == self.comma_token:
                    last_comma = len(chunk.tokens)

                # this is when we are at the end of the allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
                # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
                elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
                    break_location = last_comma + 1

                    reloc_tokens = chunk.tokens[break_location:]
                    reloc_mults = chunk.multipliers[break_location:]

                    chunk.tokens = chunk.tokens[:break_location]
                    chunk.multipliers = chunk.multipliers[:break_location]

                    next_chunk()
                    chunk.tokens = reloc_tokens
                    chunk.multipliers = reloc_mults

                if len(chunk.tokens) == self.chunk_length:
                    next_chunk()

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
                if embedding is None:
                    chunk.tokens.append(token)
                    chunk.multipliers.append(weight)
                    position += 1
                    continue

                emb_len = int(embedding.vec.shape[0])
                if len(chunk.tokens) + emb_len > self.chunk_length:
                    next_chunk()

                chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))

                chunk.tokens += [0] * emb_len
                chunk.multipliers += [weight] * emb_len
                position += embedding_length_in_tokens

        if len(chunk.tokens) > 0 or len(chunks) == 0:
            next_chunk()

        return chunks, token_count

    def process_texts(self, texts):
        """
        Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and the maximum
        length, in tokens, of all texts.
        """

        token_count = 0

        cache = {}
        batch_chunks = []
        for line in texts:
            if line in cache:
                chunks = cache[line]
            else:
                chunks, current_token_count = self.tokenize_line(line)
                token_count = max(current_token_count, token_count)

                cache[line] = chunks

            batch_chunks.append(chunks)

        return batch_chunks, token_count

    def forward(self, texts):
        """
        Accepts an array of texts; Passes texts through the transformers network to create a tensor with a numerical representation of those texts.
        Returns a tensor with shape of (B, T, C), where B is the length of the array; T is the length, in tokens, of the texts (including padding) - T will
        be a multiple of 77; and C is the dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
        An example shape returned by this function can be: (2, 77, 768).
        Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
        is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
        """

        if opts.use_old_emphasis_implementation:
            import modules.sd_hijack_clip_old
            return modules.sd_hijack_clip_old.forward_old(self, texts)

        batch_chunks, token_count = self.process_texts(texts)

        used_embeddings = {}
        chunk_count = max([len(x) for x in batch_chunks])

        zs = []
        for i in range(chunk_count):
            batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]

            tokens = [x.tokens for x in batch_chunk]
            multipliers = [x.multipliers for x in batch_chunk]
            self.hijack.fixes = [x.fixes for x in batch_chunk]

            for fixes in self.hijack.fixes:
                for position, embedding in fixes:
                    used_embeddings[embedding.name] = embedding

            z = self.process_tokens(tokens, multipliers)
            zs.append(z)

        if len(used_embeddings) > 0:
            embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
            self.hijack.comments.append(f"Used embeddings: {embeddings_list}")

        return torch.hstack(zs)

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        """
        sends one single prompt chunk to be encoded by the transformers neural network.
        remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
        there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
        Multipliers are used to give more or less weight to the outputs of the transformers network. Each multiplier
        corresponds to one token.
        """
        tokens = torch.asarray(remade_batch_tokens).to(devices.device)

        # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
        if self.id_end != self.id_pad:
            for batch_pos in range(len(remade_batch_tokens)):
                index = remade_batch_tokens[batch_pos].index(self.id_end)
                tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad

        z = self.encode_with_transformers(tokens)

        # restoring the original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
        original_mean = z.mean()
        z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z = z * (original_mean / new_mean)

        return z


class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)
        self.tokenizer = wrapped.tokenizer

        vocab = self.tokenizer.get_vocab()

        self.comma_token = vocab.get(',</w>', None)

        self.token_mults = {}
        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult

        self.id_start = self.wrapped.tokenizer.bos_token_id
        self.id_end = self.wrapped.tokenizer.eos_token_id
        self.id_pad = self.id_end

    def tokenize(self, texts):
        tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

        return tokenized

    def encode_with_transformers(self, tokens):
        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)

        if opts.CLIP_stop_at_last_layers > 1:
            z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
            z = self.wrapped.transformer.text_model.final_layer_norm(z)
        else:
            z = outputs.last_hidden_state

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        embedding_layer = self.wrapped.transformer.text_model.embeddings
        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)

        return embedded
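A worked example of the chunk arithmetic described in the docstrings above (assuming `chunk_length == 75` and the SD1 embedding width of 768):

```python
import math

chunk_length = 75
token_count = 100          # a prompt that tokenizes to 100 tokens

# get_target_prompt_token_count: round up to whole 75-token chunks
target = math.ceil(max(token_count, 1) / chunk_length) * chunk_length   # 150
num_chunks = target // chunk_length                                     # 2

# each chunk gains a start and an end token, so the encoded length is 77 per chunk;
# forward() then returns a tensor of shape (batch, num_chunks * 77, 768) for SD1
print(target, num_chunks, num_chunks * 77)   # 150 2 154
```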
@ -0,0 +1,81 @@
from modules import sd_hijack_clip
from modules import shared


def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
    id_start = self.id_start
    id_end = self.id_end
    maxlen = self.wrapped.max_length  # you get to stay at 77
    used_custom_terms = []
    remade_batch_tokens = []
    hijack_comments = []
    hijack_fixes = []
    token_count = 0

    cache = {}
    batch_tokens = self.tokenize(texts)
    batch_multipliers = []
    for tokens in batch_tokens:
        tuple_tokens = tuple(tokens)

        if tuple_tokens in cache:
            remade_tokens, fixes, multipliers = cache[tuple_tokens]
        else:
            fixes = []
            remade_tokens = []
            multipliers = []
            mult = 1.0

            i = 0
            while i < len(tokens):
                token = tokens[i]

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)

                mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
                if mult_change is not None:
                    mult *= mult_change
                    i += 1
                elif embedding is None:
                    remade_tokens.append(token)
                    multipliers.append(mult)
                    i += 1
                else:
                    emb_len = int(embedding.vec.shape[0])
                    fixes.append((len(remade_tokens), embedding))
                    remade_tokens += [0] * emb_len
                    multipliers += [mult] * emb_len
                    used_custom_terms.append((embedding.name, embedding.checksum()))
                    i += embedding_length_in_tokens

            if len(remade_tokens) > maxlen - 2:
                vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
                ovf = remade_tokens[maxlen - 2:]
                overflowing_words = [vocab.get(int(x), "") for x in ovf]
                overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
                hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

            token_count = len(remade_tokens)
            remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
            remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
            cache[tuple_tokens] = (remade_tokens, fixes, multipliers)

        multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
        multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

        remade_batch_tokens.append(remade_tokens)
        hijack_fixes.append(fixes)
        batch_multipliers.append(multipliers)
    return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count


def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
    batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)

    self.hijack.comments += hijack_comments

    if len(used_custom_terms) > 0:
        self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))

    self.hijack.fixes = hijack_fixes
    return self.process_tokens(remade_batch_tokens, batch_multipliers)
@ -0,0 +1,111 @@
import os
import torch

from einops import repeat
from omegaconf import ListConfig

import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms

from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
# assumed import: norm_thresholding is used in the dynamic_threshold branch below and lives in
# ldm.models.diffusion.sampling_util in the stablediffusion repo this function is copied from
from ldm.models.diffusion.sampling_util import norm_thresholding


@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
    b, *_, device = *x.shape, x.device

    def get_model_output(x, t):
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)

            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [
                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
                            for i in range(len(c[k]))
                        ]
                    else:
                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
            else:
                c_in = torch.cat([unconditional_conditioning, c])

            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        return e_t

    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

    def get_x_prev_and_pred_x0(e_t, index):
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        if dynamic_threshold is not None:
            pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    e_t = get_model_output(x, t)
    if len(old_eps) == 0:
        # Pseudo Improved Euler (2nd order)
        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
        e_t_next = get_model_output(x_prev, t_next)
        e_t_prime = (e_t + e_t_next) / 2
    elif len(old_eps) == 1:
        # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (3 * e_t - old_eps[-1]) / 2
    elif len(old_eps) == 2:
        # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    elif len(old_eps) >= 3:
        # 4th order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

    x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

    return x_prev, pred_x0, e_t


def should_hijack_inpainting(checkpoint_info):
    from modules import sd_models

    ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
    cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()

    return "inpainting" in ckpt_basename and "inpainting" not in cfg_basename


def do_inpainting_hijack():
    # p_sample_plms is needed because PLMS can't work with dicts as conditionings

    ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
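A sketch of how the two functions above fit together in a model loader — hedged: `checkpoint_info` is assumed to come from `modules.sd_models`, and the exact call site may differ:

```python
from modules import sd_hijack_inpainting

# before constructing an inpainting model, swap in the dict-aware PLMS sampler step
if sd_hijack_inpainting.should_hijack_inpainting(checkpoint_info):
    sd_hijack_inpainting.do_inpainting_hijack()
```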
@ -0,0 +1,37 @@
import open_clip.tokenizer
import torch

from modules import sd_hijack_clip, devices
from modules.shared import opts

tokenizer = open_clip.tokenizer._tokenizer


class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)

        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
        self.id_start = tokenizer.encoder["<start_of_text>"]
        self.id_end = tokenizer.encoder["<end_of_text>"]
        self.id_pad = 0

    def tokenize(self, texts):
        assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'

        tokenized = [tokenizer.encode(text) for text in texts]

        return tokenized

    def encode_with_transformers(self, tokens):
        # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
        z = self.wrapped.encode_with_transformer(tokens)

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        ids = tokenizer.encode(init_text)
        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
        embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)

        return embedded
@ -0,0 +1,30 @@
import torch


class TorchHijackForUnet:
    """
    This is torch, but with a cat that resizes tensors to appropriate dimensions if they do not match;
    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
    """

    def __getattr__(self, item):
        if item == 'cat':
            return self.cat

        if hasattr(torch, item):
            return getattr(torch, item)

        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))

    def cat(self, tensors, *args, **kwargs):
        if len(tensors) == 2:
            a, b = tensors
            if a.shape[-2:] != b.shape[-2:]:
                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")

            tensors = (a, b)

        return torch.cat(tensors, *args, **kwargs)


th = TorchHijackForUnet()
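A small illustration of why the `cat` override matters: with image sizes that are multiples of 8 but not 64, UNet skip connections can produce spatially mismatched tensors that plain `torch.cat` would reject:

```python
import torch

a = torch.zeros(1, 4, 13, 13)   # e.g. an upsampled decoder feature map
b = torch.zeros(1, 4, 12, 12)   # matching skip connection, one pixel smaller

# torch.cat((a, b), dim=1) would raise a size-mismatch error;
# th.cat first nearest-neighbour-resizes `a` to b's spatial size
out = th.cat((a, b), dim=1)
print(out.shape)                # torch.Size([1, 8, 12, 12])
```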
@ -0,0 +1,34 @@
import open_clip.tokenizer
import torch

from modules import sd_hijack_clip, devices
from modules.shared import opts


class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)

        self.id_start = wrapped.config.bos_token_id
        self.id_end = wrapped.config.eos_token_id
        self.id_pad = wrapped.config.pad_token_id

        self.comma_token = self.tokenizer.get_vocab().get(',', None)  # alt diffusion doesn't have </w> bits for comma

    def encode_with_transformers(self, tokens):
        # there's no CLIP skip here because all hidden layers have a size of 1024, and the last one uses a
        # trained layer to transform those 1024 into 768 for the unet; so you can't choose which transformer
        # layer to work with - you have to use the last one

        attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
        features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
        z = features['projection_state']

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        embedding_layer = self.wrapped.roberta.embeddings
        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)

        return embedded
@ -0,0 +1,243 @@
import torch
import safetensors.torch
import os
import collections
from collections import namedtuple
from modules import shared, devices, script_callbacks, sd_models
from modules.paths import models_path
import glob
from copy import deepcopy


model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
vae_dir = "VAE"
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))


vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}


default_vae_dict = {"auto": "auto", "None": None, None: None}
default_vae_list = ["auto", "None"]


default_vae_values = [default_vae_dict[x] for x in default_vae_list]
vae_dict = dict(default_vae_dict)
vae_list = list(default_vae_list)
first_load = True


base_vae = None
loaded_vae_file = None
checkpoint_info = None

checkpoints_loaded = collections.OrderedDict()


def get_base_vae(model):
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
        return base_vae
    return None


def store_base_vae(model):
    global base_vae, checkpoint_info
    if checkpoint_info != model.sd_checkpoint_info:
        assert not loaded_vae_file, "Trying to store non-base VAE!"
        base_vae = deepcopy(model.first_stage_model.state_dict())
        checkpoint_info = model.sd_checkpoint_info


def delete_base_vae():
    global base_vae, checkpoint_info
    base_vae = None
    checkpoint_info = None


def restore_base_vae(model):
    global loaded_vae_file
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
        print("Restoring base VAE")
        _load_vae_dict(model, base_vae)
        loaded_vae_file = None
    delete_base_vae()


def get_filename(filepath):
    return os.path.splitext(os.path.basename(filepath))[0]
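Note that get_filename strips only the final extension, so a companion VAE keeps its .vae suffix as its display name. A small self-contained illustration:

import os

def get_filename(filepath):  # same helper as above
    return os.path.splitext(os.path.basename(filepath))[0]

print(get_filename("models/Stable-diffusion/anything.vae.pt"))  # anything.vae
print(get_filename("models/VAE/kl-f8-anime.safetensors"))  # kl-f8-anime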
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
    global vae_dict, vae_list
    res = {}
    candidates = [
        *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
        *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
        *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True),
        *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
        *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
        *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True),
    ]
    if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
        candidates.append(shared.cmd_opts.vae_path)
    for filepath in candidates:
        name = get_filename(filepath)
        res[name] = filepath
    vae_list.clear()
    vae_list.extend(default_vae_list)
    vae_list.extend(list(res.keys()))
    vae_dict.clear()
    vae_dict.update(res)
    vae_dict.update(default_vae_dict)
    return vae_list
def get_vae_from_settings(vae_file="auto"):
    # else, we load from settings, if not set to be default
    if vae_file == "auto" and shared.opts.sd_vae is not None:
        # if the saved VAE setting isn't recognized, fall back to auto
        vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
        # if a VAE is selected but not found, fall back to auto
        if vae_file not in default_vae_values and not os.path.isfile(vae_file):
            print(f"Selected VAE doesn't exist: {vae_file}")
            vae_file = "auto"
    return vae_file


def resolve_vae(checkpoint_file=None, vae_file="auto"):
    global first_load, vae_dict, vae_list

    # if the vae_file argument is provided, it takes priority, but is not saved
    if vae_file and vae_file not in default_vae_list:
        if not os.path.isfile(vae_file):
            print(f"VAE provided as function argument doesn't exist: {vae_file}")
            vae_file = "auto"
    # for the first load, if --vae-path is provided, it takes priority, is saved, and failure is reported
    if first_load and shared.cmd_opts.vae_path is not None:
        if os.path.isfile(shared.cmd_opts.vae_path):
            vae_file = shared.cmd_opts.vae_path
            shared.opts.data['sd_vae'] = get_filename(vae_file)
        else:
            print(f"VAE provided as command line argument doesn't exist: {shared.cmd_opts.vae_path}")
    # fall back to the selector in settings, if the VAE selector is not set to act as default fallback
    if not shared.opts.sd_vae_as_default:
        vae_file = get_vae_from_settings(vae_file)
    # the --vae-path cmd arg takes priority for "auto"
    if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
        if os.path.isfile(shared.cmd_opts.vae_path):
            vae_file = shared.cmd_opts.vae_path
            print(f"Using VAE provided as command line argument: {vae_file}")
    # if still not found, look for a ".vae.pt" beside the model
    model_path = os.path.splitext(checkpoint_file)[0]
    if vae_file == "auto":
        vae_file_try = model_path + ".vae.pt"
        if os.path.isfile(vae_file_try):
            vae_file = vae_file_try
            print(f"Using VAE found similar to selected model: {vae_file}")
    # if still not found, look for a ".vae.ckpt" beside the model
    if vae_file == "auto":
        vae_file_try = model_path + ".vae.ckpt"
        if os.path.isfile(vae_file_try):
            vae_file = vae_file_try
            print(f"Using VAE found similar to selected model: {vae_file}")
    # if still not found, look for a ".vae.safetensors" beside the model
    if vae_file == "auto":
        vae_file_try = model_path + ".vae.safetensors"
        if os.path.isfile(vae_file_try):
            vae_file = vae_file_try
            print(f"Using VAE found similar to selected model: {vae_file}")
    # no more fallbacks for "auto"
    if vae_file == "auto":
        vae_file = None
    # last check, just because
    if vae_file and not os.path.exists(vae_file):
        vae_file = None

    return vae_file
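The three "look beside the model" blocks above differ only in the extension they probe. A hedged, behavior-equivalent refactor sketch (illustrative only, not the shipped code; the helper name is hypothetical):

import os

def find_vae_near_checkpoint(checkpoint_file):
    # return the first companion VAE found next to the model checkpoint
    base = os.path.splitext(checkpoint_file)[0]
    for ext in (".vae.pt", ".vae.ckpt", ".vae.safetensors"):
        candidate = base + ext
        if os.path.isfile(candidate):
            return candidate
    return None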
def load_vae(model, vae_file=None):
    global first_load, vae_dict, vae_list, loaded_vae_file
    # save_settings = False

    cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0

    if vae_file:
        if cache_enabled and vae_file in checkpoints_loaded:
            # use vae checkpoint cache
            print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
            store_base_vae(model)
            _load_vae_dict(model, checkpoints_loaded[vae_file])
        else:
            assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
            print(f"Loading VAE weights from: {vae_file}")
            store_base_vae(model)

            vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
            vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
            _load_vae_dict(model, vae_dict_1)

            if cache_enabled:
                # cache newly loaded vae
                checkpoints_loaded[vae_file] = vae_dict_1.copy()

        # clean up cache if limit is reached
        if cache_enabled:
            while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1:  # we need to count the current model
                checkpoints_loaded.popitem(last=False)  # LRU

        # if the VAE used is not in the dict, add it
        # it will be removed on refresh, though
        vae_opt = get_filename(vae_file)
        if vae_opt not in vae_dict:
            vae_dict[vae_opt] = vae_file
            vae_list.append(vae_opt)
    elif loaded_vae_file:
        restore_base_vae(model)

    loaded_vae_file = vae_file

    first_load = False
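The cache trim above leans on OrderedDict preserving insertion order: popitem(last=False) evicts the oldest entry. Strictly speaking this is FIFO rather than true LRU, since entries are not re-ordered when served from the cache branch. A minimal standalone illustration of the eviction rule:

from collections import OrderedDict

cache = OrderedDict()
for name in ("vae_a", "vae_b", "vae_c"):  # hypothetical cached VAEs
    cache[name] = {}  # stand-in for a cached state dict

limit = 1  # plays the role of shared.opts.sd_vae_checkpoint_cache
while len(cache) > limit + 1:  # same rule as above: count the currently loaded VAE too
    evicted, _ = cache.popitem(last=False)
    print("evicted:", evicted)  # evicted: vae_a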
# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
    model.first_stage_model.load_state_dict(vae_dict_1)
    model.first_stage_model.to(devices.dtype_vae)


def clear_loaded_vae():
    global loaded_vae_file
    loaded_vae_file = None


def reload_vae_weights(sd_model=None, vae_file="auto"):
    from modules import lowvram, devices, sd_hijack

    if not sd_model:
        sd_model = shared.sd_model

    checkpoint_info = sd_model.sd_checkpoint_info
    checkpoint_file = checkpoint_info.filename
    vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)

    if loaded_vae_file == vae_file:
        return

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
        sd_model.to(devices.cpu)

    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_vae(sd_model, vae_file)

    sd_hijack.model_hijack.hijack(sd_model)
    script_callbacks.model_loaded_callback(sd_model)

    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)

    print("VAE Weights loaded.")
    return sd_model
@ -0,0 +1,58 @@
import os

import torch
from torch import nn
from modules import devices, paths

sd_vae_approx_model = None


class VAEApprox(nn.Module):
    def __init__(self):
        super(VAEApprox, self).__init__()
        self.conv1 = nn.Conv2d(4, 8, (7, 7))
        self.conv2 = nn.Conv2d(8, 16, (5, 5))
        self.conv3 = nn.Conv2d(16, 32, (3, 3))
        self.conv4 = nn.Conv2d(32, 64, (3, 3))
        self.conv5 = nn.Conv2d(64, 32, (3, 3))
        self.conv6 = nn.Conv2d(32, 16, (3, 3))
        self.conv7 = nn.Conv2d(16, 8, (3, 3))
        self.conv8 = nn.Conv2d(8, 3, (3, 3))

    def forward(self, x):
        extra = 11
        x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
        x = nn.functional.pad(x, (extra, extra, extra, extra))

        for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8]:
            x = layer(x)
            x = nn.functional.leaky_relu(x, 0.1)

        return x
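The extra = 11 padding exactly offsets the 22 pixels of spatial shrinkage from the eight unpadded convolutions (6 + 4 + 6 × 2), so the output is always twice the latent's spatial size. A quick shape check, assuming a standard 64×64 SD latent:

import torch

vae_approx = VAEApprox()  # the class defined above, untrained weights are fine for a shape check
latent = torch.randn(1, 4, 64, 64)  # hypothetical batch of one latent
with torch.no_grad():
    rgb = vae_approx(latent)
print(rgb.shape)  # torch.Size([1, 3, 128, 128])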
def model():
    global sd_vae_approx_model

    if sd_vae_approx_model is None:
        sd_vae_approx_model = VAEApprox()
        sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
        sd_vae_approx_model.eval()
        sd_vae_approx_model.to(devices.device, devices.dtype)

    return sd_vae_approx_model


def cheap_approximation(sample):
    # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2

    coefs = torch.tensor([
        [0.298, 0.207, 0.208],
        [0.187, 0.286, 0.173],
        [-0.158, 0.189, 0.264],
        [-0.184, -0.271, -0.473],
    ]).to(sample.device)

    x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)

    return x_sample
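Unlike the learned approximation, cheap_approximation needs no weights at all: it projects the 4 latent channels to RGB with a fixed linear map. A hedged usage sketch; the rescaling to [0, 1] for display is an assumption, not something the source performs:

import torch

latent = torch.randn(4, 64, 64)  # hypothetical single latent; note: no batch dimension
rgb = cheap_approximation(latent)  # shape (3, 64, 64), values roughly in [-1, 1]
image = torch.clamp((rgb + 1.0) / 2.0, 0.0, 1.0)  # assumed normalization for display
print(image.shape)  # torch.Size([3, 64, 64])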
Some files were not shown because too many files have changed in this diff.