Merge branch 'master' into xygrid_infotext_improvements
commit
32547f2721
@ -1,32 +0,0 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: bug-report
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**To Reproduce**
|
||||
Steps to reproduce the behavior:
|
||||
1. Go to '...'
|
||||
2. Click on '....'
|
||||
3. Scroll down to '....'
|
||||
4. See error
|
||||
|
||||
**Expected behavior**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Screenshots**
|
||||
If applicable, add screenshots to help explain your problem.
|
||||
|
||||
**Desktop (please complete the following information):**
|
||||
- OS: [e.g. Windows, Linux]
|
||||
- Browser [e.g. chrome, safari]
|
||||
- Commit revision [looks like this: e68484500f76a33ba477d5a99340ab30451e557b; can be seen when launching webui.bat, or obtained manually by running `git rev-parse HEAD`]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
||||
@ -0,0 +1,83 @@
|
||||
name: Bug Report
|
||||
description: You think somethings is broken in the UI
|
||||
title: "[Bug]: "
|
||||
labels: ["bug-report"]
|
||||
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Is there an existing issue for this?
|
||||
description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
|
||||
options:
|
||||
- label: I have searched the existing issues and checked the recent builds/commits
|
||||
required: true
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
*Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible**
|
||||
- type: textarea
|
||||
id: what-did
|
||||
attributes:
|
||||
label: What happened?
|
||||
description: Tell us what happened in a very clear and simple way
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: steps
|
||||
attributes:
|
||||
label: Steps to reproduce the problem
|
||||
description: Please provide us with precise step by step information on how to reproduce the bug
|
||||
value: |
|
||||
1. Go to ....
|
||||
2. Press ....
|
||||
3. ...
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: what-should
|
||||
attributes:
|
||||
label: What should have happened?
|
||||
description: tell what you think the normal behavior should be
|
||||
validations:
|
||||
required: true
|
||||
- type: input
|
||||
id: commit
|
||||
attributes:
|
||||
label: Commit where the problem happens
|
||||
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
id: platforms
|
||||
attributes:
|
||||
label: What platforms do you use to access UI ?
|
||||
multiple: true
|
||||
options:
|
||||
- Windows
|
||||
- Linux
|
||||
- MacOS
|
||||
- iOS
|
||||
- Android
|
||||
- Other/Cloud
|
||||
- type: dropdown
|
||||
id: browsers
|
||||
attributes:
|
||||
label: What browsers do you use to access the UI ?
|
||||
multiple: true
|
||||
options:
|
||||
- Mozilla Firefox
|
||||
- Google Chrome
|
||||
- Brave
|
||||
- Apple Safari
|
||||
- Microsoft Edge
|
||||
- type: textarea
|
||||
id: cmdargs
|
||||
attributes:
|
||||
label: Command Line Arguments
|
||||
description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below
|
||||
render: Shell
|
||||
- type: textarea
|
||||
id: misc
|
||||
attributes:
|
||||
label: Additional information, context and logs
|
||||
description: Please provide us with any relevant additional info, context or log output.
|
||||
@ -0,0 +1,5 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: WebUI Community Support
|
||||
url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
|
||||
about: Please ask and answer questions here.
|
||||
@ -1,20 +0,0 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
**Describe the solution you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Describe alternatives you've considered**
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
**Additional context**
|
||||
Add any other context or screenshots about the feature request here.
|
||||
@ -0,0 +1,40 @@
|
||||
name: Feature request
|
||||
description: Suggest an idea for this project
|
||||
title: "[Feature Request]: "
|
||||
labels: ["suggestion"]
|
||||
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Is there an existing issue for this?
|
||||
description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit.
|
||||
options:
|
||||
- label: I have searched the existing issues and checked the recent builds/commits
|
||||
required: true
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
*Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible*
|
||||
- type: textarea
|
||||
id: feature
|
||||
attributes:
|
||||
label: What would your feature do ?
|
||||
description: Tell us about your feature in a very clear and simple way, and what problem it would solve
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: workflow
|
||||
attributes:
|
||||
label: Proposed workflow
|
||||
description: Please provide us with step by step information on how you'd like the feature to be accessed and used
|
||||
value: |
|
||||
1. Go to ....
|
||||
2. Press ....
|
||||
3. ...
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: misc
|
||||
attributes:
|
||||
label: Additional information
|
||||
description: Add any other context or screenshots about the feature request here.
|
||||
@ -0,0 +1,28 @@
|
||||
# Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request!
|
||||
|
||||
If you have a large change, pay special attention to this paragraph:
|
||||
|
||||
> Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature.
|
||||
|
||||
Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on.
|
||||
|
||||
**Describe what this pull request is trying to achieve.**
|
||||
|
||||
A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code.
|
||||
|
||||
**Additional notes and description of your changes**
|
||||
|
||||
More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of.
|
||||
|
||||
**Environment this was tested in**
|
||||
|
||||
List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
|
||||
- OS: [e.g. Windows, Linux]
|
||||
- Browser [e.g. chrome, safari]
|
||||
- Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
|
||||
|
||||
**Screenshots or videos of your changes**
|
||||
|
||||
If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made.
|
||||
|
||||
This is **required** for anything that touches the user interface.
|
||||
@ -0,0 +1,42 @@
|
||||
# See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file
|
||||
name: Run Linting/Formatting on Pull Requests
|
||||
|
||||
on:
|
||||
- push
|
||||
- pull_request
|
||||
# See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs
|
||||
# if you want to filter out branches, delete the `- pull_request` and uncomment these lines :
|
||||
# pull_request:
|
||||
# branches:
|
||||
# - master
|
||||
# branches-ignore:
|
||||
# - development
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v3
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: 3.10.6
|
||||
- uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-
|
||||
- name: Install PyLint
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pylint
|
||||
# This lets PyLint check to see if it can resolve imports
|
||||
- name: Install dependencies
|
||||
run : |
|
||||
export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
|
||||
python launch.py
|
||||
- name: Analysing the code with pylint
|
||||
run: |
|
||||
pylint $(git ls-files '*.py')
|
||||
@ -0,0 +1,31 @@
|
||||
name: Run basic features tests on CPU with empty SD model
|
||||
|
||||
on:
|
||||
- push
|
||||
- pull_request
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v3
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.10.6
|
||||
- uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
|
||||
restore-keys: ${{ runner.os }}-pip-
|
||||
- name: Run tests
|
||||
run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
|
||||
- name: Upload main app stdout-stderr
|
||||
uses: actions/upload-artifact@v3
|
||||
if: always()
|
||||
with:
|
||||
name: stdout-stderr
|
||||
path: |
|
||||
test/stdout.txt
|
||||
test/stderr.txt
|
||||
@ -0,0 +1,3 @@
|
||||
# See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html
|
||||
[MESSAGES CONTROL]
|
||||
disable=C,R,W,E,I
|
||||
@ -0,0 +1,12 @@
|
||||
* @AUTOMATIC1111
|
||||
|
||||
# if you were managing a localization and were removed from this file, this is because
|
||||
# the intended way to do localizations now is via extensions. See:
|
||||
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions
|
||||
# Make a repo with your localization and since you are still listed as a collaborator
|
||||
# you can add it to the wiki page yourself. This change is because some people complained
|
||||
# the git commit log is cluttered with things unrelated to almost everyone and
|
||||
# because I believe this is the best overall for the project to handle localizations almost
|
||||
# entirely without my oversight.
|
||||
|
||||
|
||||
|
@ -0,0 +1,72 @@
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: modules.xlmr.BertSeriesModelWithTransformation
|
||||
params:
|
||||
name: "XLMR-Large"
|
||||
@ -0,0 +1,70 @@
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
@ -0,0 +1,6 @@
|
||||
import os
|
||||
from modules import paths
|
||||
|
||||
|
||||
def preload(parser):
|
||||
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
|
||||
@ -0,0 +1,286 @@
|
||||
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
|
||||
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
|
||||
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
|
||||
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
import torch.nn.functional as F
|
||||
from contextlib import contextmanager
|
||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
import ldm.models.autoencoder
|
||||
|
||||
class VQModel(pl.LightningModule):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
batch_resize_range=None,
|
||||
scheduler_config=None,
|
||||
lr_g_factor=1.0,
|
||||
remap=None,
|
||||
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||
use_ema=False
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.n_embed = n_embed
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
||||
remap=remap,
|
||||
sane_index_shape=sane_index_shape)
|
||||
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels)==int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
self.batch_resize_range = batch_resize_range
|
||||
if self.batch_resize_range is not None:
|
||||
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
|
||||
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self)
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
self.scheduler_config = scheduler_config
|
||||
self.lr_g_factor = lr_g_factor
|
||||
|
||||
@contextmanager
|
||||
def ema_scope(self, context=None):
|
||||
if self.use_ema:
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
print(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||
if len(missing) > 0:
|
||||
print(f"Missing Keys: {missing}")
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
return quant, emb_loss, info
|
||||
|
||||
def encode_to_prequant(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, quant):
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
def decode_code(self, code_b):
|
||||
quant_b = self.quantize.embed_code(code_b)
|
||||
dec = self.decode(quant_b)
|
||||
return dec
|
||||
|
||||
def forward(self, input, return_pred_indices=False):
|
||||
quant, diff, (_,_,ind) = self.encode(input)
|
||||
dec = self.decode(quant)
|
||||
if return_pred_indices:
|
||||
return dec, diff, ind
|
||||
return dec, diff
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
if self.batch_resize_range is not None:
|
||||
lower_size = self.batch_resize_range[0]
|
||||
upper_size = self.batch_resize_range[1]
|
||||
if self.global_step <= 4:
|
||||
# do the first few batches with max size to avoid later oom
|
||||
new_resize = upper_size
|
||||
else:
|
||||
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
|
||||
if new_resize != x.shape[2]:
|
||||
x = F.interpolate(x, size=new_resize, mode="bicubic")
|
||||
x = x.detach()
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
# https://github.com/pytorch/pytorch/issues/37142
|
||||
# try not to fool the heuristics
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# autoencode
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train",
|
||||
predicted_indices=ind)
|
||||
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# discriminator
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
log_dict = self._validation_step(batch, batch_idx)
|
||||
with self.ema_scope():
|
||||
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
||||
return log_dict
|
||||
|
||||
def _validation_step(self, batch, batch_idx, suffix=""):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val"+suffix,
|
||||
predicted_indices=ind
|
||||
)
|
||||
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val"+suffix,
|
||||
predicted_indices=ind
|
||||
)
|
||||
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
self.log(f"val{suffix}/rec_loss", rec_loss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
self.log(f"val{suffix}/aeloss", aeloss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
||||
del log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr_d = self.learning_rate
|
||||
lr_g = self.lr_g_factor*self.learning_rate
|
||||
print("lr_d", lr_d)
|
||||
print("lr_g", lr_g)
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quantize.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr_g, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr_d, betas=(0.5, 0.9))
|
||||
|
||||
if self.scheduler_config is not None:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
|
||||
print("Setting up LambdaLR scheduler...")
|
||||
scheduler = [
|
||||
{
|
||||
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
},
|
||||
{
|
||||
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
},
|
||||
]
|
||||
return [opt_ae, opt_disc], scheduler
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if only_inputs:
|
||||
log["inputs"] = x
|
||||
return log
|
||||
xrec, _ = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["inputs"] = x
|
||||
log["reconstructions"] = xrec
|
||||
if plot_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, _ = self(x)
|
||||
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
||||
log["reconstructions_ema"] = xrec_ema
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class VQModelInterface(VQModel):
|
||||
def __init__(self, embed_dim, *args, **kwargs):
|
||||
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, h, force_not_quantize=False):
|
||||
# also go through quantization layer
|
||||
if not force_not_quantize:
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
else:
|
||||
quant = h
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
setattr(ldm.models.autoencoder, "VQModel", VQModel)
|
||||
setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,6 @@
|
||||
import os
|
||||
from modules import paths
|
||||
|
||||
|
||||
def preload(parser):
|
||||
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
|
||||
@ -0,0 +1,265 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
from timm.models.layers import trunc_normal_, DropPath
|
||||
|
||||
|
||||
class WMSA(nn.Module):
|
||||
""" Self-attention module in Swin Transformer
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
|
||||
super(WMSA, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.head_dim = head_dim
|
||||
self.scale = self.head_dim ** -0.5
|
||||
self.n_heads = input_dim // head_dim
|
||||
self.window_size = window_size
|
||||
self.type = type
|
||||
self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
|
||||
|
||||
self.relative_position_params = nn.Parameter(
|
||||
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
|
||||
|
||||
self.linear = nn.Linear(self.input_dim, self.output_dim)
|
||||
|
||||
trunc_normal_(self.relative_position_params, std=.02)
|
||||
self.relative_position_params = torch.nn.Parameter(
|
||||
self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
|
||||
2).transpose(
|
||||
0, 1))
|
||||
|
||||
def generate_mask(self, h, w, p, shift):
|
||||
""" generating the mask of SW-MSA
|
||||
Args:
|
||||
shift: shift parameters in CyclicShift.
|
||||
Returns:
|
||||
attn_mask: should be (1 1 w p p),
|
||||
"""
|
||||
# supporting square.
|
||||
attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
|
||||
if self.type == 'W':
|
||||
return attn_mask
|
||||
|
||||
s = p - shift
|
||||
attn_mask[-1, :, :s, :, s:, :] = True
|
||||
attn_mask[-1, :, s:, :, :s, :] = True
|
||||
attn_mask[:, -1, :, :s, :, s:] = True
|
||||
attn_mask[:, -1, :, s:, :, :s] = True
|
||||
attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
|
||||
return attn_mask
|
||||
|
||||
def forward(self, x):
|
||||
""" Forward pass of Window Multi-head Self-attention module.
|
||||
Args:
|
||||
x: input tensor with shape of [b h w c];
|
||||
attn_mask: attention mask, fill -inf where the value is True;
|
||||
Returns:
|
||||
output: tensor shape [b h w c]
|
||||
"""
|
||||
if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
|
||||
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
||||
h_windows = x.size(1)
|
||||
w_windows = x.size(2)
|
||||
# square validation
|
||||
# assert h_windows == w_windows
|
||||
|
||||
x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
|
||||
qkv = self.embedding_layer(x)
|
||||
q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
|
||||
sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
|
||||
# Adding learnable relative embedding
|
||||
sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
|
||||
# Using Attn Mask to distinguish different subwindows.
|
||||
if self.type != 'W':
|
||||
attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
|
||||
sim = sim.masked_fill_(attn_mask, float("-inf"))
|
||||
|
||||
probs = nn.functional.softmax(sim, dim=-1)
|
||||
output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
|
||||
output = rearrange(output, 'h b w p c -> b w p (h c)')
|
||||
output = self.linear(output)
|
||||
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
||||
|
||||
if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
|
||||
dims=(1, 2))
|
||||
return output
|
||||
|
||||
def relative_embedding(self):
|
||||
cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
|
||||
relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
|
||||
# negative is allowed
|
||||
return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
||||
""" SwinTransformer Block
|
||||
"""
|
||||
super(Block, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
assert type in ['W', 'SW']
|
||||
self.type = type
|
||||
if input_resolution <= window_size:
|
||||
self.type = 'W'
|
||||
|
||||
self.ln1 = nn.LayerNorm(input_dim)
|
||||
self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.ln2 = nn.LayerNorm(input_dim)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(input_dim, 4 * input_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(4 * input_dim, output_dim),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.msa(self.ln1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.ln2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class ConvTransBlock(nn.Module):
|
||||
def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
||||
""" SwinTransformer and Conv Block
|
||||
"""
|
||||
super(ConvTransBlock, self).__init__()
|
||||
self.conv_dim = conv_dim
|
||||
self.trans_dim = trans_dim
|
||||
self.head_dim = head_dim
|
||||
self.window_size = window_size
|
||||
self.drop_path = drop_path
|
||||
self.type = type
|
||||
self.input_resolution = input_resolution
|
||||
|
||||
assert self.type in ['W', 'SW']
|
||||
if self.input_resolution <= self.window_size:
|
||||
self.type = 'W'
|
||||
|
||||
self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
|
||||
self.type, self.input_resolution)
|
||||
self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
|
||||
self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
|
||||
|
||||
self.conv_block = nn.Sequential(
|
||||
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
|
||||
conv_x = self.conv_block(conv_x) + conv_x
|
||||
trans_x = Rearrange('b c h w -> b h w c')(trans_x)
|
||||
trans_x = self.trans_block(trans_x)
|
||||
trans_x = Rearrange('b h w c -> b c h w')(trans_x)
|
||||
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
|
||||
x = x + res
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SCUNet(nn.Module):
|
||||
# def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
|
||||
def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
|
||||
super(SCUNet, self).__init__()
|
||||
if config is None:
|
||||
config = [2, 2, 2, 2, 2, 2, 2]
|
||||
self.config = config
|
||||
self.dim = dim
|
||||
self.head_dim = 32
|
||||
self.window_size = 8
|
||||
|
||||
# drop path rate for each layer
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
|
||||
|
||||
self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
|
||||
|
||||
begin = 0
|
||||
self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
|
||||
'W' if not i % 2 else 'SW', input_resolution)
|
||||
for i in range(config[0])] + \
|
||||
[nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
|
||||
|
||||
begin += config[0]
|
||||
self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
|
||||
'W' if not i % 2 else 'SW', input_resolution // 2)
|
||||
for i in range(config[1])] + \
|
||||
[nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
|
||||
|
||||
begin += config[1]
|
||||
self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
||||
'W' if not i % 2 else 'SW', input_resolution // 4)
|
||||
for i in range(config[2])] + \
|
||||
[nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
|
||||
|
||||
begin += config[2]
|
||||
self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
||||
'W' if not i % 2 else 'SW', input_resolution // 8)
|
||||
for i in range(config[3])]
|
||||
|
||||
begin += config[3]
|
||||
self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
|
||||
[ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
||||
'W' if not i % 2 else 'SW', input_resolution // 4)
|
||||
for i in range(config[4])]
|
||||
|
||||
begin += config[4]
|
||||
self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
|
||||
[ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
|
||||
'W' if not i % 2 else 'SW', input_resolution // 2)
|
||||
for i in range(config[5])]
|
||||
|
||||
begin += config[5]
|
||||
self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
|
||||
[ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
|
||||
'W' if not i % 2 else 'SW', input_resolution)
|
||||
for i in range(config[6])]
|
||||
|
||||
self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
|
||||
|
||||
self.m_head = nn.Sequential(*self.m_head)
|
||||
self.m_down1 = nn.Sequential(*self.m_down1)
|
||||
self.m_down2 = nn.Sequential(*self.m_down2)
|
||||
self.m_down3 = nn.Sequential(*self.m_down3)
|
||||
self.m_body = nn.Sequential(*self.m_body)
|
||||
self.m_up3 = nn.Sequential(*self.m_up3)
|
||||
self.m_up2 = nn.Sequential(*self.m_up2)
|
||||
self.m_up1 = nn.Sequential(*self.m_up1)
|
||||
self.m_tail = nn.Sequential(*self.m_tail)
|
||||
# self.apply(self._init_weights)
|
||||
|
||||
def forward(self, x0):
|
||||
|
||||
h, w = x0.size()[-2:]
|
||||
paddingBottom = int(np.ceil(h / 64) * 64 - h)
|
||||
paddingRight = int(np.ceil(w / 64) * 64 - w)
|
||||
x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
|
||||
|
||||
x1 = self.m_head(x0)
|
||||
x2 = self.m_down1(x1)
|
||||
x3 = self.m_down2(x2)
|
||||
x4 = self.m_down3(x3)
|
||||
x = self.m_body(x4)
|
||||
x = self.m_up3(x + x4)
|
||||
x = self.m_up2(x + x3)
|
||||
x = self.m_up1(x + x2)
|
||||
x = self.m_tail(x + x1)
|
||||
|
||||
x = x[..., :h, :w]
|
||||
|
||||
return x
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
@ -0,0 +1,6 @@
|
||||
import os
|
||||
from modules import paths
|
||||
|
||||
|
||||
def preload(parser):
|
||||
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,107 @@
|
||||
// Stable Diffusion WebUI - Bracket checker
|
||||
// Version 1.0
|
||||
// By Hingashi no Florin/Bwin4L
|
||||
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
|
||||
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
|
||||
|
||||
function checkBrackets(evt) {
|
||||
textArea = evt.target;
|
||||
tabName = evt.target.parentElement.parentElement.id.split("_")[0];
|
||||
counterElt = document.querySelector('gradio-app').shadowRoot.querySelector('#' + tabName + '_token_counter');
|
||||
|
||||
promptName = evt.target.parentElement.parentElement.id.includes('neg') ? ' negative' : '';
|
||||
|
||||
errorStringParen = '(' + tabName + promptName + ' prompt) - Different number of opening and closing parentheses detected.\n';
|
||||
errorStringSquare = '[' + tabName + promptName + ' prompt] - Different number of opening and closing square brackets detected.\n';
|
||||
errorStringCurly = '{' + tabName + promptName + ' prompt} - Different number of opening and closing curly brackets detected.\n';
|
||||
|
||||
openBracketRegExp = /\(/g;
|
||||
closeBracketRegExp = /\)/g;
|
||||
|
||||
openSquareBracketRegExp = /\[/g;
|
||||
closeSquareBracketRegExp = /\]/g;
|
||||
|
||||
openCurlyBracketRegExp = /\{/g;
|
||||
closeCurlyBracketRegExp = /\}/g;
|
||||
|
||||
totalOpenBracketMatches = 0;
|
||||
totalCloseBracketMatches = 0;
|
||||
totalOpenSquareBracketMatches = 0;
|
||||
totalCloseSquareBracketMatches = 0;
|
||||
totalOpenCurlyBracketMatches = 0;
|
||||
totalCloseCurlyBracketMatches = 0;
|
||||
|
||||
openBracketMatches = textArea.value.match(openBracketRegExp);
|
||||
if(openBracketMatches) {
|
||||
totalOpenBracketMatches = openBracketMatches.length;
|
||||
}
|
||||
|
||||
closeBracketMatches = textArea.value.match(closeBracketRegExp);
|
||||
if(closeBracketMatches) {
|
||||
totalCloseBracketMatches = closeBracketMatches.length;
|
||||
}
|
||||
|
||||
openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
|
||||
if(openSquareBracketMatches) {
|
||||
totalOpenSquareBracketMatches = openSquareBracketMatches.length;
|
||||
}
|
||||
|
||||
closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
|
||||
if(closeSquareBracketMatches) {
|
||||
totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
|
||||
}
|
||||
|
||||
openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
|
||||
if(openCurlyBracketMatches) {
|
||||
totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
|
||||
}
|
||||
|
||||
closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
|
||||
if(closeCurlyBracketMatches) {
|
||||
totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
|
||||
}
|
||||
|
||||
if(totalOpenBracketMatches != totalCloseBracketMatches) {
|
||||
if(!counterElt.title.includes(errorStringParen)) {
|
||||
counterElt.title += errorStringParen;
|
||||
}
|
||||
} else {
|
||||
counterElt.title = counterElt.title.replace(errorStringParen, '');
|
||||
}
|
||||
|
||||
if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
|
||||
if(!counterElt.title.includes(errorStringSquare)) {
|
||||
counterElt.title += errorStringSquare;
|
||||
}
|
||||
} else {
|
||||
counterElt.title = counterElt.title.replace(errorStringSquare, '');
|
||||
}
|
||||
|
||||
if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
|
||||
if(!counterElt.title.includes(errorStringCurly)) {
|
||||
counterElt.title += errorStringCurly;
|
||||
}
|
||||
} else {
|
||||
counterElt.title = counterElt.title.replace(errorStringCurly, '');
|
||||
}
|
||||
|
||||
if(counterElt.title != '') {
|
||||
counterElt.style = 'color: #FF5555;';
|
||||
} else {
|
||||
counterElt.style = '';
|
||||
}
|
||||
}
|
||||
|
||||
var shadowRootLoaded = setInterval(function() {
|
||||
var shadowTextArea = document.querySelector('gradio-app').shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
|
||||
if(shadowTextArea.length < 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
clearInterval(shadowRootLoaded);
|
||||
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_prompt').onkeyup = checkBrackets;
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_neg_prompt').onkeyup = checkBrackets;
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_prompt').onkeyup = checkBrackets;
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_neg_prompt').onkeyup = checkBrackets;
|
||||
}, 1000);
|
||||
@ -0,0 +1,50 @@
|
||||
import random
|
||||
|
||||
from modules import script_callbacks, shared
|
||||
import gradio as gr
|
||||
|
||||
art_symbol = '\U0001f3a8' # 🎨
|
||||
global_prompt = None
|
||||
related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
|
||||
|
||||
|
||||
def roll_artist(prompt):
|
||||
allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
|
||||
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
|
||||
|
||||
return prompt + ", " + artist.name if prompt != '' else artist.name
|
||||
|
||||
|
||||
def add_roll_button(prompt):
|
||||
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
||||
|
||||
roll.click(
|
||||
fn=roll_artist,
|
||||
_js="update_txt2img_tokens",
|
||||
inputs=[
|
||||
prompt,
|
||||
],
|
||||
outputs=[
|
||||
prompt,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def after_component(component, **kwargs):
|
||||
global global_prompt
|
||||
|
||||
elem_id = kwargs.get('elem_id', None)
|
||||
if elem_id not in related_ids:
|
||||
return
|
||||
|
||||
if elem_id == "txt2img_prompt":
|
||||
global_prompt = component
|
||||
elif elem_id == "txt2img_clear_prompt":
|
||||
add_roll_button(global_prompt)
|
||||
elif elem_id == "img2img_prompt":
|
||||
global_prompt = component
|
||||
elif elem_id == "img2img_clear_prompt":
|
||||
add_roll_button(global_prompt)
|
||||
|
||||
|
||||
script_callbacks.on_after_component(after_component)
|
||||
@ -0,0 +1,9 @@
|
||||
<div>
|
||||
<a href="/docs">API</a>
|
||||
•
|
||||
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
|
||||
•
|
||||
<a href="https://gradio.app">Gradio</a>
|
||||
•
|
||||
<a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
|
||||
</div>
|
||||
@ -0,0 +1,392 @@
|
||||
<style>
|
||||
#licenses h2 {font-size: 1.2em; font-weight: bold; margin-bottom: 0.2em;}
|
||||
#licenses small {font-size: 0.95em; opacity: 0.85;}
|
||||
#licenses pre { margin: 1em 0 2em 0;}
|
||||
</style>
|
||||
|
||||
<h2><a href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">CodeFormer</a></h2>
|
||||
<small>Parts of CodeFormer code had to be copied to be compatible with GFPGAN.</small>
|
||||
<pre>
|
||||
S-Lab License 1.0
|
||||
|
||||
Copyright 2022 S-Lab
|
||||
|
||||
Redistribution and use for non-commercial purpose in source and
|
||||
binary forms, with or without modification, are permitted provided
|
||||
that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
In the event that redistribution and/or use for commercial purpose in
|
||||
source or binary forms, with or without modification is required,
|
||||
please contact the contributor(s) of the work.
|
||||
</pre>
|
||||
|
||||
|
||||
<h2><a href="https://github.com/victorca25/iNNfer/blob/main/LICENSE">ESRGAN</a></h2>
|
||||
<small>Code for architecture and reading models copied.</small>
|
||||
<pre>
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2021 victorca25
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE">Real-ESRGAN</a></h2>
|
||||
<small>Some code is copied to support ESRGAN models.</small>
|
||||
<pre>
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2021, Xintao Wang
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE">InvokeAI</a></h2>
|
||||
<small>Some code for compatibility with OSX is taken from lstein's repository.</small>
|
||||
<pre>
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 InvokeAI Team
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/Hafiidz/latent-diffusion/blob/main/LICENSE">LDSR</a></h2>
|
||||
<small>Code added by contirubtors, most likely copied from this repository.</small>
|
||||
<pre>
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/pharmapsychotic/clip-interrogator/blob/main/LICENSE">CLIP Interrogator</a></h2>
|
||||
<small>Some small amounts of code borrowed and reworked.</small>
|
||||
<pre>
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 pharmapsychotic
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
|
||||
<small>Code added by contirubtors, most likely copied from this repository.</small>
|
||||
|
||||
<pre>
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [2021] [SwinIR Authors]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
</pre>
|
||||
|
||||
@ -0,0 +1,177 @@
|
||||
|
||||
contextMenuInit = function(){
|
||||
let eventListenerApplied=false;
|
||||
let menuSpecs = new Map();
|
||||
|
||||
const uid = function(){
|
||||
return Date.now().toString(36) + Math.random().toString(36).substr(2);
|
||||
}
|
||||
|
||||
function showContextMenu(event,element,menuEntries){
|
||||
let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
|
||||
let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
|
||||
|
||||
let oldMenu = gradioApp().querySelector('#context-menu')
|
||||
if(oldMenu){
|
||||
oldMenu.remove()
|
||||
}
|
||||
|
||||
let tabButton = uiCurrentTab
|
||||
let baseStyle = window.getComputedStyle(tabButton)
|
||||
|
||||
const contextMenu = document.createElement('nav')
|
||||
contextMenu.id = "context-menu"
|
||||
contextMenu.style.background = baseStyle.background
|
||||
contextMenu.style.color = baseStyle.color
|
||||
contextMenu.style.fontFamily = baseStyle.fontFamily
|
||||
contextMenu.style.top = posy+'px'
|
||||
contextMenu.style.left = posx+'px'
|
||||
|
||||
|
||||
|
||||
const contextMenuList = document.createElement('ul')
|
||||
contextMenuList.className = 'context-menu-items';
|
||||
contextMenu.append(contextMenuList);
|
||||
|
||||
menuEntries.forEach(function(entry){
|
||||
let contextMenuEntry = document.createElement('a')
|
||||
contextMenuEntry.innerHTML = entry['name']
|
||||
contextMenuEntry.addEventListener("click", function(e) {
|
||||
entry['func']();
|
||||
})
|
||||
contextMenuList.append(contextMenuEntry);
|
||||
|
||||
})
|
||||
|
||||
gradioApp().getRootNode().appendChild(contextMenu)
|
||||
|
||||
let menuWidth = contextMenu.offsetWidth + 4;
|
||||
let menuHeight = contextMenu.offsetHeight + 4;
|
||||
|
||||
let windowWidth = window.innerWidth;
|
||||
let windowHeight = window.innerHeight;
|
||||
|
||||
if ( (windowWidth - posx) < menuWidth ) {
|
||||
contextMenu.style.left = windowWidth - menuWidth + "px";
|
||||
}
|
||||
|
||||
if ( (windowHeight - posy) < menuHeight ) {
|
||||
contextMenu.style.top = windowHeight - menuHeight + "px";
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
|
||||
|
||||
currentItems = menuSpecs.get(targetElementSelector)
|
||||
|
||||
if(!currentItems){
|
||||
currentItems = []
|
||||
menuSpecs.set(targetElementSelector,currentItems);
|
||||
}
|
||||
let newItem = {'id':targetElementSelector+'_'+uid(),
|
||||
'name':entryName,
|
||||
'func':entryFunction,
|
||||
'isNew':true}
|
||||
|
||||
currentItems.push(newItem)
|
||||
return newItem['id']
|
||||
}
|
||||
|
||||
function removeContextMenuOption(uid){
|
||||
menuSpecs.forEach(function(v,k) {
|
||||
let index = -1
|
||||
v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
|
||||
if(index>=0){
|
||||
v.splice(index, 1);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
function addContextMenuEventListener(){
|
||||
if(eventListenerApplied){
|
||||
return;
|
||||
}
|
||||
gradioApp().addEventListener("click", function(e) {
|
||||
let source = e.composedPath()[0]
|
||||
if(source.id && source.id.indexOf('check_progress')>-1){
|
||||
return
|
||||
}
|
||||
|
||||
let oldMenu = gradioApp().querySelector('#context-menu')
|
||||
if(oldMenu){
|
||||
oldMenu.remove()
|
||||
}
|
||||
});
|
||||
gradioApp().addEventListener("contextmenu", function(e) {
|
||||
let oldMenu = gradioApp().querySelector('#context-menu')
|
||||
if(oldMenu){
|
||||
oldMenu.remove()
|
||||
}
|
||||
menuSpecs.forEach(function(v,k) {
|
||||
if(e.composedPath()[0].matches(k)){
|
||||
showContextMenu(e,e.composedPath()[0],v)
|
||||
e.preventDefault()
|
||||
return
|
||||
}
|
||||
})
|
||||
});
|
||||
eventListenerApplied=true
|
||||
|
||||
}
|
||||
|
||||
return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
|
||||
}
|
||||
|
||||
initResponse = contextMenuInit();
|
||||
appendContextMenuOption = initResponse[0];
|
||||
removeContextMenuOption = initResponse[1];
|
||||
addContextMenuEventListener = initResponse[2];
|
||||
|
||||
(function(){
|
||||
//Start example Context Menu Items
|
||||
let generateOnRepeat = function(genbuttonid,interruptbuttonid){
|
||||
let genbutton = gradioApp().querySelector(genbuttonid);
|
||||
let interruptbutton = gradioApp().querySelector(interruptbuttonid);
|
||||
if(!interruptbutton.offsetParent){
|
||||
genbutton.click();
|
||||
}
|
||||
clearInterval(window.generateOnRepeatInterval)
|
||||
window.generateOnRepeatInterval = setInterval(function(){
|
||||
if(!interruptbutton.offsetParent){
|
||||
genbutton.click();
|
||||
}
|
||||
},
|
||||
500)
|
||||
}
|
||||
|
||||
appendContextMenuOption('#txt2img_generate','Generate forever',function(){
|
||||
generateOnRepeat('#txt2img_generate','#txt2img_interrupt');
|
||||
})
|
||||
appendContextMenuOption('#img2img_generate','Generate forever',function(){
|
||||
generateOnRepeat('#img2img_generate','#img2img_interrupt');
|
||||
})
|
||||
|
||||
let cancelGenerateForever = function(){
|
||||
clearInterval(window.generateOnRepeatInterval)
|
||||
}
|
||||
|
||||
appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
|
||||
appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever)
|
||||
appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
|
||||
appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
|
||||
|
||||
appendContextMenuOption('#roll','Roll three',
|
||||
function(){
|
||||
let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
|
||||
setTimeout(function(){rollbutton.click()},100)
|
||||
setTimeout(function(){rollbutton.click()},200)
|
||||
setTimeout(function(){rollbutton.click()},300)
|
||||
}
|
||||
)
|
||||
})();
|
||||
//End example Context Menu Items
|
||||
|
||||
onUiUpdate(function(){
|
||||
addContextMenuEventListener()
|
||||
});
|
||||
@ -0,0 +1,75 @@
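// Ctrl/Cmd + ArrowUp/ArrowDown on the prompt textarea adjusts the (text:weight)
// attention weight around the cursor or current selection (descriptive summary
// of the handler below).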
|
||||
addEventListener('keydown', (event) => {
|
||||
let target = event.originalTarget || event.composedPath()[0];
|
||||
if (!target.matches("#toprow textarea.gr-text-input[placeholder]")) return;
|
||||
if (! (event.metaKey || event.ctrlKey)) return;
|
||||
|
||||
|
||||
let plus = "ArrowUp"
|
||||
let minus = "ArrowDown"
|
||||
if (event.key != plus && event.key != minus) return;
|
||||
|
||||
let selectionStart = target.selectionStart;
|
||||
let selectionEnd = target.selectionEnd;
|
||||
// If the user hasn't selected anything, let's select their current parenthesis block
|
||||
if (selectionStart === selectionEnd) {
|
||||
// Find opening parenthesis around current cursor
|
||||
const before = target.value.substring(0, selectionStart);
|
||||
let beforeParen = before.lastIndexOf("(");
|
||||
if (beforeParen == -1) return;
|
||||
let beforeParenClose = before.lastIndexOf(")");
|
||||
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
|
||||
beforeParen = before.lastIndexOf("(", beforeParen - 1);
|
||||
beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1);
|
||||
}
|
||||
|
||||
// Find closing parenthesis around current cursor
|
||||
const after = target.value.substring(selectionStart);
|
||||
let afterParen = after.indexOf(")");
|
||||
if (afterParen == -1) return;
|
||||
let afterParenOpen = after.indexOf("(");
|
||||
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
|
||||
afterParen = after.indexOf(")", afterParen + 1);
|
||||
afterParenOpen = after.indexOf("(", afterParenOpen + 1);
|
||||
}
|
||||
if (beforeParen === -1 || afterParen === -1) return;
|
||||
|
||||
// Set the selection to the text between the parenthesis
|
||||
const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen);
|
||||
const lastColon = parenContent.lastIndexOf(":");
|
||||
selectionStart = beforeParen + 1;
|
||||
selectionEnd = selectionStart + lastColon;
|
||||
target.setSelectionRange(selectionStart, selectionEnd);
|
||||
}
|
||||
|
||||
event.preventDefault();
|
||||
|
||||
if (selectionStart == 0 || target.value[selectionStart - 1] != "(") {
|
||||
target.value = target.value.slice(0, selectionStart) +
|
||||
"(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" +
|
||||
target.value.slice(selectionEnd);
|
||||
|
||||
target.focus();
|
||||
target.selectionStart = selectionStart + 1;
|
||||
target.selectionEnd = selectionEnd + 1;
|
||||
|
||||
} else {
|
||||
end = target.value.slice(selectionEnd + 1).indexOf(")") + 1;
|
||||
weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end));
|
||||
if (isNaN(weight)) return;
|
||||
if (event.key == minus) weight -= 0.1;
|
||||
if (event.key == plus) weight += 0.1;
|
||||
|
||||
weight = parseFloat(weight.toPrecision(12));
|
||||
|
||||
target.value = target.value.slice(0, selectionEnd + 1) +
|
||||
weight +
|
||||
target.value.slice(selectionEnd + 1 + end - 1);
|
||||
|
||||
target.focus();
|
||||
target.selectionStart = selectionStart;
|
||||
target.selectionEnd = selectionEnd;
|
||||
}
|
||||
// Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its
|
||||
// internal Svelte data binding remains in sync.
|
||||
target.dispatchEvent(new Event("input", { bubbles: true }));
|
||||
});
|
||||
@ -0,0 +1,35 @@
|
||||
|
||||
function extensions_apply(_, _){
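// Note: this function appears to be wired up as a Gradio js callback (an assumption
// based on how the extensions tab invokes it): it receives the current values of the
// two hidden textboxes as its ignored arguments, and the returned pair of JSON strings
// replaces them before the Python handler runs.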
|
||||
disable = []
|
||||
update = []
|
||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
||||
if(x.name.startsWith("enable_") && ! x.checked)
|
||||
disable.push(x.name.substr(7))
|
||||
|
||||
if(x.name.startsWith("update_") && x.checked)
|
||||
update.push(x.name.substr(7))
|
||||
})
|
||||
|
||||
restart_reload()
|
||||
|
||||
return [JSON.stringify(disable), JSON.stringify(update)]
|
||||
}
|
||||
|
||||
function extensions_check(){
|
||||
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
||||
x.innerHTML = "Loading..."
|
||||
})
|
||||
|
||||
return []
|
||||
}
|
||||
|
||||
function install_extension_from_index(button, url){
|
||||
button.disabled = "disabled"
|
||||
button.value = "Installing..."
|
||||
|
||||
textarea = gradioApp().querySelector('#extension_to_install textarea')
|
||||
textarea.value = url
|
||||
textarea.dispatchEvent(new Event("input", { bubbles: true }))
|
||||
|
||||
gradioApp().querySelector('#install_extension_button').click()
|
||||
}
|
||||
@ -0,0 +1,33 @@
|
||||
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
|
||||
|
||||
let txt2img_gallery, img2img_gallery, modal = undefined;
|
||||
onUiUpdate(function(){
|
||||
if (!txt2img_gallery) {
|
||||
txt2img_gallery = attachGalleryListeners("txt2img")
|
||||
}
|
||||
if (!img2img_gallery) {
|
||||
img2img_gallery = attachGalleryListeners("img2img")
|
||||
}
|
||||
if (!modal) {
|
||||
modal = gradioApp().getElementById('lightboxModal')
|
||||
modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
|
||||
}
|
||||
});
|
||||
|
||||
let modalObserver = new MutationObserver(function(mutations) {
|
||||
mutations.forEach(function(mutationRecord) {
|
||||
let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
|
||||
        if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img'))
|
||||
gradioApp().getElementById(selectedTab+"_generation_info_button").click()
|
||||
});
|
||||
});
|
||||
|
||||
function attachGalleryListeners(tab_name) {
|
||||
gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
|
||||
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
|
||||
gallery?.addEventListener('keydown', (e) => {
|
||||
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
|
||||
gradioApp().getElementById(tab_name+"_generation_info_button").click()
|
||||
});
|
||||
return gallery;
|
||||
}
|
||||
@ -0,0 +1,19 @@
|
||||
window.onload = (function(){
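// Forwards an image dropped onto a prompt textbox into the hidden txt2img/img2img
// "*_prompt_image" file input so the backend can read generation parameters from it
// (descriptive summary of the drop handler below).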
|
||||
window.addEventListener('drop', e => {
|
||||
const target = e.composedPath()[0];
|
||||
const idx = selected_gallery_index();
|
||||
if (target.placeholder.indexOf("Prompt") == -1) return;
|
||||
|
||||
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
|
||||
|
||||
e.stopPropagation();
|
||||
e.preventDefault();
|
||||
const imgParent = gradioApp().getElementById(prompt_target);
|
||||
const files = e.dataTransfer.files;
|
||||
const fileInput = imgParent.querySelector('input[type="file"]');
|
||||
if ( fileInput ) {
|
||||
fileInput.files = files;
|
||||
fileInput.dispatchEvent(new Event('change'));
|
||||
}
|
||||
});
|
||||
});
|
||||
@ -0,0 +1,167 @@
|
||||
|
||||
// localization = {} -- the dict with translations is created by the backend
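// A hypothetical example of its shape once populated:
//   localization = {"rtl": false, "Sampling steps": "Pasos de muestreo"}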
|
||||
|
||||
ignore_ids_for_localization={
|
||||
setting_sd_hypernetwork: 'OPTION',
|
||||
setting_sd_model_checkpoint: 'OPTION',
|
||||
setting_realesrgan_enabled_models: 'OPTION',
|
||||
modelmerger_primary_model_name: 'OPTION',
|
||||
modelmerger_secondary_model_name: 'OPTION',
|
||||
modelmerger_tertiary_model_name: 'OPTION',
|
||||
train_embedding: 'OPTION',
|
||||
train_hypernetwork: 'OPTION',
|
||||
txt2img_style_index: 'OPTION',
|
||||
txt2img_style2_index: 'OPTION',
|
||||
img2img_style_index: 'OPTION',
|
||||
img2img_style2_index: 'OPTION',
|
||||
setting_random_artist_categories: 'SPAN',
|
||||
setting_face_restoration_model: 'SPAN',
|
||||
setting_realesrgan_enabled_models: 'SPAN',
|
||||
extras_upscaler_1: 'SPAN',
|
||||
extras_upscaler_2: 'SPAN',
|
||||
}
|
||||
|
||||
re_num = /^[\.\d]+$/
|
||||
re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
|
||||
|
||||
original_lines = {}
|
||||
translated_lines = {}
|
||||
|
||||
function textNodesUnder(el){
|
||||
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
|
||||
while(n=walk.nextNode()) a.push(n);
|
||||
return a;
|
||||
}
|
||||
|
||||
function canBeTranslated(node, text){
|
||||
if(! text) return false;
|
||||
if(! node.parentElement) return false;
|
||||
|
||||
parentType = node.parentElement.nodeName
|
||||
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
|
||||
|
||||
if (parentType=='OPTION' || parentType=='SPAN'){
|
||||
pnode = node
|
||||
for(var level=0; level<4; level++){
|
||||
pnode = pnode.parentElement
|
||||
if(! pnode) break;
|
||||
|
||||
if(ignore_ids_for_localization[pnode.id] == parentType) return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(re_num.test(text)) return false;
|
||||
if(re_emoji.test(text)) return false;
|
||||
return true
|
||||
}
|
||||
|
||||
function getTranslation(text){
|
||||
if(! text) return undefined
|
||||
|
||||
if(translated_lines[text] === undefined){
|
||||
original_lines[text] = 1
|
||||
}
|
||||
|
||||
tl = localization[text]
|
||||
if(tl !== undefined){
|
||||
translated_lines[tl] = 1
|
||||
}
|
||||
|
||||
return tl
|
||||
}
|
||||
|
||||
function processTextNode(node){
|
||||
text = node.textContent.trim()
|
||||
|
||||
if(! canBeTranslated(node, text)) return
|
||||
|
||||
tl = getTranslation(text)
|
||||
if(tl !== undefined){
|
||||
node.textContent = tl
|
||||
}
|
||||
}
|
||||
|
||||
function processNode(node){
|
||||
if(node.nodeType == 3){
|
||||
processTextNode(node)
|
||||
return
|
||||
}
|
||||
|
||||
if(node.title){
|
||||
tl = getTranslation(node.title)
|
||||
if(tl !== undefined){
|
||||
node.title = tl
|
||||
}
|
||||
}
|
||||
|
||||
if(node.placeholder){
|
||||
tl = getTranslation(node.placeholder)
|
||||
if(tl !== undefined){
|
||||
node.placeholder = tl
|
||||
}
|
||||
}
|
||||
|
||||
textNodesUnder(node).forEach(function(node){
|
||||
processTextNode(node)
|
||||
})
|
||||
}
|
||||
|
||||
function dumpTranslations(){
|
||||
dumped = {}
|
||||
if (localization.rtl) {
|
||||
dumped.rtl = true
|
||||
}
|
||||
|
||||
Object.keys(original_lines).forEach(function(text){
|
||||
if(dumped[text] !== undefined) return
|
||||
|
||||
dumped[text] = localization[text] || text
|
||||
})
|
||||
|
||||
return dumped
|
||||
}
|
||||
|
||||
onUiUpdate(function(m){
|
||||
m.forEach(function(mutation){
|
||||
mutation.addedNodes.forEach(function(node){
|
||||
processNode(node)
|
||||
})
|
||||
});
|
||||
})
|
||||
|
||||
|
||||
document.addEventListener("DOMContentLoaded", function() {
|
||||
processNode(gradioApp())
|
||||
|
||||
if (localization.rtl) { // if the language is from right to left,
|
||||
(new MutationObserver((mutations, observer) => { // wait for the style to load
|
||||
mutations.forEach(mutation => {
|
||||
mutation.addedNodes.forEach(node => {
|
||||
if (node.tagName === 'STYLE') {
|
||||
observer.disconnect();
|
||||
|
||||
for (const x of node.sheet.rules) { // find all rtl media rules
|
||||
if (Array.from(x.media || []).includes('rtl')) {
|
||||
x.media.appendMedium('all'); // enable them
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
})).observe(gradioApp(), { childList: true });
|
||||
}
|
||||
})
|
||||
|
||||
function download_localization() {
|
||||
text = JSON.stringify(dumpTranslations(), null, 4)
|
||||
|
||||
var element = document.createElement('a');
|
||||
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
|
||||
element.setAttribute('download', "localization.json");
|
||||
element.style.display = 'none';
|
||||
document.body.appendChild(element);
|
||||
|
||||
element.click();
|
||||
|
||||
document.body.removeChild(element);
|
||||
}
|
||||
@ -0,0 +1,8 @@
|
||||
|
||||
|
||||
function start_training_textual_inversion(){
|
||||
requestProgress('ti')
|
||||
gradioApp().querySelector('#ti_error').innerHTML=''
|
||||
|
||||
return args_to_array(arguments)
|
||||
}
|
||||
Binary file not shown.
@ -0,0 +1,462 @@
|
||||
import base64
|
||||
import io
|
||||
import time
|
||||
import datetime
|
||||
import uvicorn
|
||||
from threading import Lock
|
||||
from io import BytesIO
|
||||
from gradio.processing_utils import decode_base64_to_file
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from secrets import compare_digest
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import sd_samplers, deepbooru, sd_hijack
|
||||
from modules.api.models import *
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.extras import run_extras, run_pnginfo
|
||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||
from modules.textual_inversion.preprocess import preprocess
|
||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||
from PIL import PngImagePlugin,Image
|
||||
from modules.sd_models import checkpoints_list, find_checkpoint_config
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
from typing import List
|
||||
|
||||
def upscaler_to_index(name: str):
|
||||
try:
|
||||
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {', '.join([x.name for x in shared.sd_upscalers])}")
|
||||
|
||||
|
||||
def validate_sampler_name(name):
|
||||
config = sd_samplers.all_samplers_map.get(name, None)
|
||||
if config is None:
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
|
||||
return name
|
||||
|
||||
def setUpscalers(req: dict):
|
||||
reqDict = vars(req)
|
||||
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
|
||||
reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
|
||||
reqDict.pop('upscaler_1')
|
||||
reqDict.pop('upscaler_2')
|
||||
return reqDict
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";")[1].split(",")[1]
|
||||
return Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
|
||||
def encode_pil_to_base64(image):
|
||||
with io.BytesIO() as output_bytes:
|
||||
|
||||
# Copy any text-only metadata
|
||||
use_metadata = False
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
for key, value in image.info.items():
|
||||
if isinstance(key, str) and isinstance(value, str):
|
||||
metadata.add_text(key, value)
|
||||
use_metadata = True
|
||||
|
||||
image.save(
|
||||
output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
return base64.b64encode(bytes_data)
|
||||
|
||||
def api_middleware(app: FastAPI):
|
||||
@app.middleware("http")
|
||||
async def log_and_time(req: Request, call_next):
|
||||
ts = time.time()
|
||||
res: Response = await call_next(req)
|
||||
duration = str(round(time.time() - ts, 4))
|
||||
res.headers["X-Process-Time"] = duration
|
||||
endpoint = req.scope.get('path', 'err')
|
||||
if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
|
||||
print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
|
||||
t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
|
||||
code = res.status_code,
|
||||
ver = req.scope.get('http_version', '0.0'),
|
||||
cli = req.scope.get('client', ('0:0.0.0', 0))[0],
|
||||
prot = req.scope.get('scheme', 'err'),
|
||||
method = req.scope.get('method', 'err'),
|
||||
endpoint = endpoint,
|
||||
duration = duration,
|
||||
))
|
||||
return res
|
||||
|
||||
|
||||
class Api:
|
||||
def __init__(self, app: FastAPI, queue_lock: Lock):
|
||||
if shared.cmd_opts.api_auth:
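# --api-auth is expected to hold comma-separated "user:password" pairs,
# e.g. "alice:secret,bob:hunter2" (format inferred from the parsing below)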
|
||||
self.credentials = dict()
|
||||
for auth in shared.cmd_opts.api_auth.split(","):
|
||||
user, password = auth.split(":")
|
||||
self.credentials[user] = password
|
||||
|
||||
self.router = APIRouter()
|
||||
self.app = app
|
||||
self.queue_lock = queue_lock
|
||||
api_middleware(self.app)
|
||||
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
|
||||
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
|
||||
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
||||
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
||||
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
||||
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
||||
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
|
||||
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
|
||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
|
||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
|
||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
|
||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
|
||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
|
||||
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
|
||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
|
||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
|
||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
|
||||
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
|
||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
||||
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
if shared.cmd_opts.api_auth:
|
||||
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
|
||||
return self.app.add_api_route(path, endpoint, **kwargs)
|
||||
|
||||
def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
|
||||
if credentials.username in self.credentials:
|
||||
if compare_digest(credentials.password, self.credentials[credentials.username]):
|
||||
return True
|
||||
|
||||
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
|
||||
|
||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
|
||||
"do_not_save_samples": True,
|
||||
"do_not_save_grid": True
|
||||
}
|
||||
)
|
||||
if populate.sampler_name:
|
||||
populate.sampler_index = None # prevent a warning later on
|
||||
|
||||
with self.queue_lock:
|
||||
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
|
||||
|
||||
shared.state.begin()
|
||||
processed = process_images(p)
|
||||
shared.state.end()
|
||||
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images))
|
||||
|
||||
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||
|
||||
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
|
||||
init_images = img2imgreq.init_images
|
||||
if init_images is None:
|
||||
raise HTTPException(status_code=404, detail="Init image not found")
|
||||
|
||||
mask = img2imgreq.mask
|
||||
if mask:
|
||||
mask = decode_base64_to_image(mask)
|
||||
|
||||
populate = img2imgreq.copy(update={ # Override __init__ params
|
||||
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
|
||||
"do_not_save_samples": True,
|
||||
"do_not_save_grid": True,
|
||||
"mask": mask
|
||||
}
|
||||
)
|
||||
if populate.sampler_name:
|
||||
populate.sampler_index = None # prevent a warning later on
|
||||
|
||||
args = vars(populate)
|
||||
args.pop('include_init_images', None)  # this was meant to be handled by "exclude": True in the model, but that does not work for a reason I cannot determine.
|
||||
|
||||
with self.queue_lock:
|
||||
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
|
||||
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
||||
|
||||
shared.state.begin()
|
||||
processed = process_images(p)
|
||||
shared.state.end()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images))
|
||||
|
||||
if not img2imgreq.include_init_images:
|
||||
img2imgreq.init_images = None
|
||||
img2imgreq.mask = None
|
||||
|
||||
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
||||
|
||||
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
|
||||
reqDict = setUpscalers(req)
|
||||
|
||||
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
||||
|
||||
with self.queue_lock:
|
||||
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
|
||||
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
||||
|
||||
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
|
||||
reqDict = setUpscalers(req)
|
||||
|
||||
def prepareFiles(file):
|
||||
file = decode_base64_to_file(file.data, file_path=file.name)
|
||||
file.orig_name = file.name
|
||||
return file
|
||||
|
||||
reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
|
||||
reqDict.pop('imageList')
|
||||
|
||||
with self.queue_lock:
|
||||
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
|
||||
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||
|
||||
def pnginfoapi(self, req: PNGInfoRequest):
|
||||
if(not req.image.strip()):
|
||||
return PNGInfoResponse(info="")
|
||||
|
||||
result = run_pnginfo(decode_base64_to_image(req.image.strip()))
|
||||
|
||||
return PNGInfoResponse(info=result[1])
|
||||
|
||||
def progressapi(self, req: ProgressRequest = Depends()):
|
||||
# copied from check_progress_call in ui.py
|
||||
|
||||
if shared.state.job_count == 0:
|
||||
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
|
||||
|
||||
# avoid dividing by zero
|
||||
progress = 0.01
|
||||
|
||||
if shared.state.job_count > 0:
|
||||
progress += shared.state.job_no / shared.state.job_count
|
||||
if shared.state.sampling_steps > 0:
|
||||
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
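# illustrative example: with 2 jobs, on job_no=1 at sampling step 10 of 20,
# progress = 0.01 + 1/2 + (1/2)*(10/20) = 0.76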
|
||||
|
||||
time_since_start = time.time() - shared.state.time_start
|
||||
eta = (time_since_start/progress)
|
||||
eta_relative = eta-time_since_start
|
||||
|
||||
progress = min(progress, 1)
|
||||
|
||||
shared.state.set_current_image()
|
||||
|
||||
current_image = None
|
||||
if shared.state.current_image and not req.skip_current_image:
|
||||
current_image = encode_pil_to_base64(shared.state.current_image)
|
||||
|
||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
|
||||
|
||||
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
||||
image_b64 = interrogatereq.image
|
||||
if image_b64 is None:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
img = decode_base64_to_image(image_b64)
|
||||
img = img.convert('RGB')
|
||||
|
||||
# Override object param
|
||||
with self.queue_lock:
|
||||
if interrogatereq.model == "clip":
|
||||
processed = shared.interrogator.interrogate(img)
|
||||
elif interrogatereq.model == "deepdanbooru":
|
||||
processed = deepbooru.model.tag(img)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
return InterrogateResponse(caption=processed)
|
||||
|
||||
def interruptapi(self):
|
||||
shared.state.interrupt()
|
||||
|
||||
return {}
|
||||
|
||||
def skip(self):
|
||||
shared.state.skip()
|
||||
|
||||
def get_config(self):
|
||||
options = {}
|
||||
for key in shared.opts.data.keys():
|
||||
metadata = shared.opts.data_labels.get(key)
|
||||
if(metadata is not None):
|
||||
options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
|
||||
else:
|
||||
options.update({key: shared.opts.data.get(key, None)})
|
||||
|
||||
return options
|
||||
|
||||
def set_config(self, req: Dict[str, Any]):
|
||||
for k, v in req.items():
|
||||
shared.opts.set(k, v)
|
||||
|
||||
shared.opts.save(shared.config_filename)
|
||||
return
|
||||
|
||||
def get_cmd_flags(self):
|
||||
return vars(shared.cmd_opts)
|
||||
|
||||
def get_samplers(self):
|
||||
return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
|
||||
|
||||
def get_upscalers(self):
|
||||
upscalers = []
|
||||
|
||||
for upscaler in shared.sd_upscalers:
|
||||
u = upscaler.scaler
|
||||
upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
|
||||
|
||||
return upscalers
|
||||
|
||||
def get_sd_models(self):
|
||||
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
|
||||
|
||||
def get_hypernetworks(self):
|
||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||
|
||||
def get_face_restorers(self):
|
||||
return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
|
||||
|
||||
def get_realesrgan_models(self):
|
||||
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
|
||||
|
||||
def get_prompt_styles(self):
|
||||
styleList = []
|
||||
for k in shared.prompt_styles.styles:
|
||||
style = shared.prompt_styles.styles[k]
|
||||
styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
|
||||
|
||||
return styleList
|
||||
|
||||
def get_artists_categories(self):
|
||||
return shared.artist_db.cats
|
||||
|
||||
def get_artists(self):
|
||||
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
|
||||
|
||||
def get_embeddings(self):
|
||||
db = sd_hijack.model_hijack.embedding_db
|
||||
|
||||
def convert_embedding(embedding):
|
||||
return {
|
||||
"step": embedding.step,
|
||||
"sd_checkpoint": embedding.sd_checkpoint,
|
||||
"sd_checkpoint_name": embedding.sd_checkpoint_name,
|
||||
"shape": embedding.shape,
|
||||
"vectors": embedding.vectors,
|
||||
}
|
||||
|
||||
def convert_embeddings(embeddings):
|
||||
return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
|
||||
|
||||
return {
|
||||
"loaded": convert_embeddings(db.word_embeddings),
|
||||
"skipped": convert_embeddings(db.skipped_embeddings),
|
||||
}
|
||||
|
||||
def refresh_checkpoints(self):
|
||||
shared.refresh_checkpoints()
|
||||
|
||||
def create_embedding(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
filename = create_embedding(**args) # create empty embedding
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
||||
shared.state.end()
|
||||
return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "create embedding error: {error}".format(error = e))
|
||||
|
||||
def create_hypernetwork(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
filename = create_hypernetwork(**args) # create empty hypernetwork
|
||||
shared.state.end()
|
||||
return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
|
||||
|
||||
def preprocess(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
||||
shared.state.end()
|
||||
return PreprocessResponse(info = 'preprocess complete')
|
||||
except KeyError as e:
|
||||
shared.state.end()
|
||||
return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
|
||||
except FileNotFoundError as e:
|
||||
shared.state.end()
|
||||
return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
|
||||
|
||||
def train_embedding(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||
error = None
|
||||
filename = ''
|
||||
if not apply_optimizations:
|
||||
sd_hijack.undo_optimizations()
|
||||
try:
|
||||
embedding, filename = train_embedding(**args) # can take a long time to complete
|
||||
except Exception as e:
|
||||
error = e
|
||||
finally:
|
||||
if not apply_optimizations:
|
||||
sd_hijack.apply_optimizations()
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
|
||||
except AssertionError as msg:
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
|
||||
|
||||
def train_hypernetwork(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
initial_hypernetwork = shared.loaded_hypernetwork
|
||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||
error = None
|
||||
filename = ''
|
||||
if not apply_optimizations:
|
||||
sd_hijack.undo_optimizations()
|
||||
try:
|
||||
hypernetwork, filename = train_hypernetwork(**args) # can take a long time to complete
|
||||
except Exception as e:
|
||||
error = e
|
||||
finally:
|
||||
shared.loaded_hypernetwork = initial_hypernetwork
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
if not apply_optimizations:
|
||||
sd_hijack.apply_optimizations()
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
|
||||
except AssertionError as msg:
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "train embedding error: {error}".format(error = error))
|
||||
|
||||
def launch(self, server_name, port):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(self.app, host=server_name, port=port)
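# A minimal client-side sketch of calling the txt2img route registered above
# (assumes the server is reachable at the default local address and port):
#
#   import requests
#   payload = {"prompt": "a photo of a cat", "steps": 20, "sampler_index": "Euler"}
#   r = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload)
#   images_base64 = r.json()["images"]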
|
||||
@ -0,0 +1,261 @@
|
||||
import inspect
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from typing import Any, Optional
|
||||
from typing_extensions import Literal
|
||||
from inflection import underscore
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||
from modules.shared import sd_upscalers, opts, parser
|
||||
from typing import Dict, List
|
||||
|
||||
API_NOT_ALLOWED = [
|
||||
"self",
|
||||
"kwargs",
|
||||
"sd_model",
|
||||
"outpath_samples",
|
||||
"outpath_grids",
|
||||
"sampler_index",
|
||||
"do_not_save_samples",
|
||||
"do_not_save_grid",
|
||||
"extra_generation_params",
|
||||
"overlay_images",
|
||||
"do_not_reload_embeddings",
|
||||
"seed_enable_extras",
|
||||
"prompt_for_display",
|
||||
"sampler_noise_scheduler_override",
|
||||
"ddim_discretize"
|
||||
]
|
||||
|
||||
class ModelDef(BaseModel):
|
||||
"""Assistance Class for Pydantic Dynamic Model Generation"""
|
||||
|
||||
field: str
|
||||
field_alias: str
|
||||
field_type: Any
|
||||
field_value: Any
|
||||
field_exclude: bool = False
|
||||
|
||||
|
||||
class PydanticModelGenerator:
|
||||
"""
|
||||
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
|
||||
source_data is a snapshot of the default values produced by the class
|
||||
params are the names of the actual keys required by __init__
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = None,
|
||||
class_instance = None,
|
||||
additional_fields = None,
|
||||
):
|
||||
def field_type_generator(k, v):
|
||||
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
||||
# print(k, v.annotation, v.default)
|
||||
field_type = v.annotation
|
||||
|
||||
return Optional[field_type]
|
||||
|
||||
def merge_class_params(class_):
|
||||
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
|
||||
parameters = {}
|
||||
for classes in all_classes:
|
||||
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||
return parameters
|
||||
|
||||
|
||||
self._model_name = model_name
|
||||
self._class_data = merge_class_params(class_instance)
|
||||
|
||||
self._model_def = [
|
||||
ModelDef(
|
||||
field=underscore(k),
|
||||
field_alias=k,
|
||||
field_type=field_type_generator(k, v),
|
||||
field_value=v.default
|
||||
)
|
||||
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||
]
|
||||
|
||||
for fields in additional_fields:
|
||||
self._model_def.append(ModelDef(
|
||||
field=underscore(fields["key"]),
|
||||
field_alias=fields["key"],
|
||||
field_type=fields["type"],
|
||||
field_value=fields["default"],
|
||||
field_exclude=fields["exclude"] if "exclude" in fields else False))
|
||||
|
||||
def generate_model(self):
|
||||
"""
|
||||
Creates a pydantic BaseModel
|
||||
from the json and overrides provided at initialization
|
||||
"""
|
||||
fields = {
|
||||
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
|
||||
}
|
||||
DynamicModel = create_model(self._model_name, **fields)
|
||||
DynamicModel.__config__.allow_population_by_field_name = True
|
||||
DynamicModel.__config__.allow_mutation = True
|
||||
return DynamicModel
|
||||
|
||||
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingTxt2Img",
|
||||
StableDiffusionProcessingTxt2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}]
|
||||
).generate_model()
|
||||
|
||||
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingImg2Img",
|
||||
StableDiffusionProcessingImg2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
|
||||
).generate_model()
|
||||
|
||||
class TextToImageResponse(BaseModel):
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
class ImageToImageResponse(BaseModel):
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
class ExtrasBaseRequest(BaseModel):
|
||||
resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
|
||||
show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
|
||||
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
|
||||
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
|
||||
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
|
||||
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
|
||||
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
|
||||
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
|
||||
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
|
||||
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
|
||||
upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
|
||||
|
||||
class ExtraBaseResponse(BaseModel):
|
||||
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
|
||||
|
||||
class ExtrasSingleImageRequest(ExtrasBaseRequest):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
|
||||
class ExtrasSingleImageResponse(ExtraBaseResponse):
|
||||
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
|
||||
class FileData(BaseModel):
|
||||
data: str = Field(title="File data", description="Base64 representation of the file")
|
||||
name: str = Field(title="File name")
|
||||
|
||||
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
|
||||
imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
||||
|
||||
class ExtrasBatchImagesResponse(ExtraBaseResponse):
|
||||
images: List[str] = Field(title="Images", description="The generated images in base64 format.")
|
||||
|
||||
class PNGInfoRequest(BaseModel):
|
||||
image: str = Field(title="Image", description="The base64 encoded PNG image")
|
||||
|
||||
class PNGInfoResponse(BaseModel):
|
||||
info: str = Field(title="Image info", description="A string with all the info the image had")
|
||||
|
||||
class ProgressRequest(BaseModel):
|
||||
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
|
||||
|
||||
class ProgressResponse(BaseModel):
|
||||
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
|
||||
eta_relative: float = Field(title="ETA in secs")
|
||||
state: dict = Field(title="State", description="The current state snapshot")
|
||||
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||
|
||||
class InterrogateRequest(BaseModel):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
model: str = Field(default="clip", title="Model", description="The interrogate model used.")
|
||||
|
||||
class InterrogateResponse(BaseModel):
|
||||
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
|
||||
|
||||
class TrainResponse(BaseModel):
|
||||
info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
|
||||
|
||||
class CreateResponse(BaseModel):
|
||||
info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
|
||||
|
||||
class PreprocessResponse(BaseModel):
|
||||
info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
|
||||
|
||||
fields = {}
|
||||
for key, metadata in opts.data_labels.items():
|
||||
value = opts.data.get(key)
|
||||
optType = opts.typemap.get(type(metadata.default), type(value))
|
||||
|
||||
if metadata is not None:
    fields.update({key: (Optional[optType], Field(default=metadata.default, description=metadata.label))})
else:
    fields.update({key: (Optional[optType], Field())})
|
||||
|
||||
OptionsModel = create_model("Options", **fields)
|
||||
|
||||
flags = {}
|
||||
_options = vars(parser)['_option_string_actions']
|
||||
for key in _options:
    if _options[key].dest != 'help':
        flag = _options[key]
        _type = str
        if _options[key].default is not None:
            _type = type(_options[key].default)
        flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
|
||||
|
||||
FlagsModel = create_model("Flags", **flags)
|
||||
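As a minimal sketch of what the dynamic model construction above produces (assuming pydantic is installed; the two option names are used here only for illustration):

from typing import Optional
from pydantic import create_model, Field

# Illustrative only: mirrors the (Optional[type], Field(...)) tuples assembled above.
DemoOptions = create_model(
    "DemoOptions",
    sd_model_checkpoint=(Optional[str], Field(default=None, description="Checkpoint name")),
    CLIP_stop_at_last_layers=(Optional[int], Field(default=1, description="Clip skip")),
)

opts_demo = DemoOptions(CLIP_stop_at_last_layers=2)
print(opts_demo.dict())  # {'sd_model_checkpoint': None, 'CLIP_stop_at_last_layers': 2}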
|
||||
class SamplerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
aliases: List[str] = Field(title="Aliases")
|
||||
options: Dict[str, str] = Field(title="Options")
|
||||
|
||||
class UpscalerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
model_name: Optional[str] = Field(title="Model Name")
|
||||
model_path: Optional[str] = Field(title="Path")
|
||||
model_url: Optional[str] = Field(title="URL")
|
||||
|
||||
class SDModelItem(BaseModel):
|
||||
title: str = Field(title="Title")
|
||||
model_name: str = Field(title="Model Name")
|
||||
hash: str = Field(title="Hash")
|
||||
filename: str = Field(title="Filename")
|
||||
config: str = Field(title="Config file")
|
||||
|
||||
class HypernetworkItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
path: Optional[str] = Field(title="Path")
|
||||
|
||||
class FaceRestorerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
cmd_dir: Optional[str] = Field(title="Path")
|
||||
|
||||
class RealesrganItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
path: Optional[str] = Field(title="Path")
|
||||
scale: Optional[int] = Field(title="Scale")
|
||||
|
||||
class PromptStyleItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
prompt: Optional[str] = Field(title="Prompt")
|
||||
negative_prompt: Optional[str] = Field(title="Negative Prompt")
|
||||
|
||||
class ArtistItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
score: float = Field(title="Score")
|
||||
category: str = Field(title="Category")
|
||||
|
||||
class EmbeddingItem(BaseModel):
|
||||
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
|
||||
sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
|
||||
sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
|
||||
shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
|
||||
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
|
||||
|
||||
class EmbeddingsResponse(BaseModel):
|
||||
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
|
||||
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
|
||||
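A hypothetical instance of the response these two models describe, with invented values:

# Invented values, shown only to illustrate the response shape.
example = EmbeddingsResponse(
    loaded={
        "my-embedding": EmbeddingItem(
            step=5000,
            sd_checkpoint="abcd1234",
            sd_checkpoint_name="v1-5-pruned",
            shape=768,
            vectors=4,
        ),
    },
    skipped={},
)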
@ -1,102 +0,0 @@
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
|
||||
|
||||
def initialize_weights(net_l, scale=1):
|
||||
if not isinstance(net_l, list):
|
||||
net_l = [net_l]
|
||||
for net in net_l:
|
||||
for m in net.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale # for residual block
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias.data, 0.0)
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
# initialization
|
||||
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
'''Residual in Residual Dense Block'''
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
self.sf = sf
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
if self.sf==4:
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
|
||||
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
if self.sf==4:
|
||||
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
return out
|
||||
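A quick sanity-check sketch for the network above (assumes torch is installed; nb is reduced to 1 so it runs fast): with sf=4, the two nearest-neighbour upsampling steps quadruple the spatial size.

import torch

net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=1, gc=32, sf=4)  # nb=1 only for this shape check
with torch.no_grad():
    y = net(torch.zeros(1, 3, 16, 16))
print(y.shape)  # torch.Size([1, 3, 64, 64])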
@ -0,0 +1,98 @@
|
||||
import html
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
import time
|
||||
|
||||
from modules import shared
|
||||
|
||||
queue_lock = threading.Lock()
|
||||
|
||||
|
||||
def wrap_queued_call(func):
|
||||
def f(*args, **kwargs):
|
||||
with queue_lock:
|
||||
res = func(*args, **kwargs)
|
||||
|
||||
return res
|
||||
|
||||
return f
|
||||
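A minimal usage sketch for wrap_queued_call, with a made-up task function, showing how callers are serialized through queue_lock:

def run_task(prompt):
    # Hypothetical stand-in for any function that must not run concurrently.
    return f"processed: {prompt}"

queued_run_task = wrap_queued_call(run_task)
# If two threads call queued_run_task at once, the second blocks until the
# first releases queue_lock.
print(queued_run_task("hello"))  # processed: hello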
|
||||
|
||||
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||
def f(*args, **kwargs):
|
||||
|
||||
shared.state.begin()
|
||||
|
||||
with queue_lock:
|
||||
res = func(*args, **kwargs)
|
||||
|
||||
shared.state.end()
|
||||
|
||||
return res
|
||||
|
||||
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
|
||||
|
||||
|
||||
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
|
||||
if run_memmon:
|
||||
shared.mem_mon.monitor()
|
||||
t = time.perf_counter()
|
||||
|
||||
try:
|
||||
res = list(func(*args, **kwargs))
|
||||
except Exception as e:
|
||||
# When printing out our debug argument list, do not print out more than 128 KiB of text
|
||||
max_debug_str_len = 131072 # (1024*1024)/8
|
||||
|
||||
print("Error completing request", file=sys.stderr)
|
||||
argStr = f"Arguments: {str(args)} {str(kwargs)}"
|
||||
print(argStr[:max_debug_str_len], file=sys.stderr)
|
||||
if len(argStr) > max_debug_str_len:
|
||||
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
|
||||
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
shared.state.job = ""
|
||||
shared.state.job_count = 0
|
||||
|
||||
if extra_outputs_array is None:
|
||||
extra_outputs_array = [None, '']
|
||||
|
||||
res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
|
||||
|
||||
shared.state.skipped = False
|
||||
shared.state.interrupted = False
|
||||
shared.state.job_count = 0
|
||||
|
||||
if not add_stats:
|
||||
return tuple(res)
|
||||
|
||||
elapsed = time.perf_counter() - t
|
||||
elapsed_m = int(elapsed // 60)
|
||||
elapsed_s = elapsed % 60
|
||||
elapsed_text = f"{elapsed_s:.2f}s"
|
||||
if elapsed_m > 0:
|
||||
elapsed_text = f"{elapsed_m}m "+elapsed_text
|
||||
|
||||
if run_memmon:
|
||||
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
|
||||
active_peak = mem_stats['active_peak']
|
||||
reserved_peak = mem_stats['reserved_peak']
|
||||
sys_peak = mem_stats['system_peak']
|
||||
sys_total = mem_stats['total']
|
||||
sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
|
||||
|
||||
vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
|
||||
else:
|
||||
vram_html = ''
|
||||
|
||||
# last item is always HTML
|
||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
|
||||
|
||||
return tuple(res)
|
||||
|
||||
return f
|
||||
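The mem_stats comprehension above relies on the negated floor-division idiom; a short sketch of what -(v // -(1024*1024)) computes (byte counts rounded up to whole MiB):

def bytes_to_mib_ceil(v: int) -> int:
    # Equivalent to math.ceil(v / (1024 * 1024)) for non-negative integers.
    return -(v // -(1024 * 1024))

assert bytes_to_mib_ceil(1) == 1
assert bytes_to_mib_ceil(1024 * 1024) == 1
assert bytes_to_mib_ceil(1024 * 1024 + 1) == 2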
|
||||
@ -0,0 +1,99 @@
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from modules import modelloader, paths, deepbooru_model, devices, images, shared
|
||||
|
||||
re_special = re.compile(r'([\\()])')
|
||||
|
||||
|
||||
class DeepDanbooru:
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
|
||||
def load(self):
|
||||
if self.model is not None:
|
||||
return
|
||||
|
||||
files = modelloader.load_models(
|
||||
model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
|
||||
model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
|
||||
ext_filter=[".pt"],
|
||||
download_name='model-resnet_custom_v3.pt',
|
||||
)
|
||||
|
||||
self.model = deepbooru_model.DeepDanbooruModel()
|
||||
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
|
||||
|
||||
self.model.eval()
|
||||
self.model.to(devices.cpu, devices.dtype)
|
||||
|
||||
def start(self):
|
||||
self.load()
|
||||
self.model.to(devices.device)
|
||||
|
||||
def stop(self):
|
||||
if not shared.opts.interrogate_keep_models_in_memory:
|
||||
self.model.to(devices.cpu)
|
||||
devices.torch_gc()
|
||||
|
||||
def tag(self, pil_image):
|
||||
self.start()
|
||||
res = self.tag_multi(pil_image)
|
||||
self.stop()
|
||||
|
||||
return res
|
||||
|
||||
def tag_multi(self, pil_image, force_disable_ranks=False):
|
||||
threshold = shared.opts.interrogate_deepbooru_score_threshold
|
||||
use_spaces = shared.opts.deepbooru_use_spaces
|
||||
use_escape = shared.opts.deepbooru_escape
|
||||
alpha_sort = shared.opts.deepbooru_sort_alpha
|
||||
include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
|
||||
|
||||
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
|
||||
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
|
||||
|
||||
with torch.no_grad(), devices.autocast():
|
||||
x = torch.from_numpy(a).to(devices.device)
|
||||
y = self.model(x)[0].detach().cpu().numpy()
|
||||
|
||||
probability_dict = {}
|
||||
|
||||
for tag, probability in zip(self.model.tags, y):
|
||||
if probability < threshold:
|
||||
continue
|
||||
|
||||
if tag.startswith("rating:"):
|
||||
continue
|
||||
|
||||
probability_dict[tag] = probability
|
||||
|
||||
if alpha_sort:
|
||||
tags = sorted(probability_dict)
|
||||
else:
|
||||
tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
|
||||
|
||||
res = []
|
||||
|
||||
filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
|
||||
|
||||
for tag in [x for x in tags if x not in filtertags]:
|
||||
probability = probability_dict[tag]
|
||||
tag_outformat = tag
|
||||
if use_spaces:
|
||||
tag_outformat = tag_outformat.replace('_', ' ')
|
||||
if use_escape:
|
||||
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
|
||||
if include_ranks:
|
||||
tag_outformat = f"({tag_outformat}:{probability:.3f})"
|
||||
|
||||
res.append(tag_outformat)
|
||||
|
||||
return ", ".join(res)
|
||||
|
||||
|
||||
model = DeepDanbooru()
|
||||
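To make the tag post-processing in tag_multi easier to follow, a standalone sketch of the three output transformations, using a made-up tag and probability:

import re

re_special = re.compile(r'([\\()])')
tag, probability = "long_hair_(style)", 0.912    # hypothetical values

tag_out = tag.replace('_', ' ')                  # opts.deepbooru_use_spaces
tag_out = re.sub(re_special, r'\\\1', tag_out)   # opts.deepbooru_escape
tag_out = f"({tag_out}:{probability:.3f})"       # opts.interrogate_return_ranks
print(tag_out)  # (long hair \(style\):0.912)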
@ -0,0 +1,676 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
|
||||
|
||||
|
||||
class DeepDanbooruModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(DeepDanbooruModel, self).__init__()
|
||||
|
||||
self.tags = []
|
||||
|
||||
self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
|
||||
self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
|
||||
self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
|
||||
self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
||||
self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
|
||||
self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
||||
self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
|
||||
self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
||||
self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
|
||||
self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
|
||||
self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
|
||||
self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
|
||||
self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
|
||||
self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
||||
self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
||||
self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
|
||||
self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
|
||||
self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
|
||||
self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
|
||||
self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||
self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
||||
self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
||||
self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||
self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
||||
self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
||||
self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||
self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
|
||||
self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
|
||||
self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
|
||||
self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||
self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
||||
self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
||||
self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||
self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
||||
self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
||||
self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||
self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
|
||||
|
||||
def forward(self, *inputs):
|
||||
t_358, = inputs
|
||||
t_359 = t_358.permute(*[0, 3, 1, 2])
|
||||
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
|
||||
t_360 = self.n_Conv_0(t_359_padded)
|
||||
t_361 = F.relu(t_360)
|
||||
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
|
||||
t_362 = self.n_MaxPool_0(t_361)
|
||||
t_363 = self.n_Conv_1(t_362)
|
||||
t_364 = self.n_Conv_2(t_362)
|
||||
t_365 = F.relu(t_364)
|
||||
t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
|
||||
t_366 = self.n_Conv_3(t_365_padded)
|
||||
t_367 = F.relu(t_366)
|
||||
t_368 = self.n_Conv_4(t_367)
|
||||
t_369 = torch.add(t_368, t_363)
|
||||
t_370 = F.relu(t_369)
|
||||
t_371 = self.n_Conv_5(t_370)
|
||||
t_372 = F.relu(t_371)
|
||||
t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
|
||||
t_373 = self.n_Conv_6(t_372_padded)
|
||||
t_374 = F.relu(t_373)
|
||||
t_375 = self.n_Conv_7(t_374)
|
||||
t_376 = torch.add(t_375, t_370)
|
||||
t_377 = F.relu(t_376)
|
||||
t_378 = self.n_Conv_8(t_377)
|
||||
t_379 = F.relu(t_378)
|
||||
t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
|
||||
t_380 = self.n_Conv_9(t_379_padded)
|
||||
t_381 = F.relu(t_380)
|
||||
t_382 = self.n_Conv_10(t_381)
|
||||
t_383 = torch.add(t_382, t_377)
|
||||
t_384 = F.relu(t_383)
|
||||
t_385 = self.n_Conv_11(t_384)
|
||||
t_386 = self.n_Conv_12(t_384)
|
||||
t_387 = F.relu(t_386)
|
||||
t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
|
||||
t_388 = self.n_Conv_13(t_387_padded)
|
||||
t_389 = F.relu(t_388)
|
||||
t_390 = self.n_Conv_14(t_389)
|
||||
t_391 = torch.add(t_390, t_385)
|
||||
t_392 = F.relu(t_391)
|
||||
t_393 = self.n_Conv_15(t_392)
|
||||
t_394 = F.relu(t_393)
|
||||
t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
|
||||
t_395 = self.n_Conv_16(t_394_padded)
|
||||
t_396 = F.relu(t_395)
|
||||
t_397 = self.n_Conv_17(t_396)
|
||||
t_398 = torch.add(t_397, t_392)
|
||||
t_399 = F.relu(t_398)
|
||||
t_400 = self.n_Conv_18(t_399)
|
||||
t_401 = F.relu(t_400)
|
||||
t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
|
||||
t_402 = self.n_Conv_19(t_401_padded)
|
||||
t_403 = F.relu(t_402)
|
||||
t_404 = self.n_Conv_20(t_403)
|
||||
t_405 = torch.add(t_404, t_399)
|
||||
t_406 = F.relu(t_405)
|
||||
t_407 = self.n_Conv_21(t_406)
|
||||
t_408 = F.relu(t_407)
|
||||
t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
|
||||
t_409 = self.n_Conv_22(t_408_padded)
|
||||
t_410 = F.relu(t_409)
|
||||
t_411 = self.n_Conv_23(t_410)
|
||||
t_412 = torch.add(t_411, t_406)
|
||||
t_413 = F.relu(t_412)
|
||||
t_414 = self.n_Conv_24(t_413)
|
||||
t_415 = F.relu(t_414)
|
||||
t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
|
||||
t_416 = self.n_Conv_25(t_415_padded)
|
||||
t_417 = F.relu(t_416)
|
||||
t_418 = self.n_Conv_26(t_417)
|
||||
t_419 = torch.add(t_418, t_413)
|
||||
t_420 = F.relu(t_419)
|
||||
t_421 = self.n_Conv_27(t_420)
|
||||
t_422 = F.relu(t_421)
|
||||
t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
|
||||
t_423 = self.n_Conv_28(t_422_padded)
|
||||
t_424 = F.relu(t_423)
|
||||
t_425 = self.n_Conv_29(t_424)
|
||||
t_426 = torch.add(t_425, t_420)
|
||||
t_427 = F.relu(t_426)
|
||||
t_428 = self.n_Conv_30(t_427)
|
||||
t_429 = F.relu(t_428)
|
||||
t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
|
||||
t_430 = self.n_Conv_31(t_429_padded)
|
||||
t_431 = F.relu(t_430)
|
||||
t_432 = self.n_Conv_32(t_431)
|
||||
t_433 = torch.add(t_432, t_427)
|
||||
t_434 = F.relu(t_433)
|
||||
t_435 = self.n_Conv_33(t_434)
|
||||
t_436 = F.relu(t_435)
|
||||
t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
|
||||
t_437 = self.n_Conv_34(t_436_padded)
|
||||
t_438 = F.relu(t_437)
|
||||
t_439 = self.n_Conv_35(t_438)
|
||||
t_440 = torch.add(t_439, t_434)
|
||||
t_441 = F.relu(t_440)
|
||||
t_442 = self.n_Conv_36(t_441)
|
||||
t_443 = self.n_Conv_37(t_441)
|
||||
t_444 = F.relu(t_443)
|
||||
t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
|
||||
t_445 = self.n_Conv_38(t_444_padded)
|
||||
t_446 = F.relu(t_445)
|
||||
t_447 = self.n_Conv_39(t_446)
|
||||
t_448 = torch.add(t_447, t_442)
|
||||
t_449 = F.relu(t_448)
|
||||
t_450 = self.n_Conv_40(t_449)
|
||||
t_451 = F.relu(t_450)
|
||||
t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
|
||||
t_452 = self.n_Conv_41(t_451_padded)
|
||||
t_453 = F.relu(t_452)
|
||||
t_454 = self.n_Conv_42(t_453)
|
||||
t_455 = torch.add(t_454, t_449)
|
||||
t_456 = F.relu(t_455)
|
||||
t_457 = self.n_Conv_43(t_456)
|
||||
t_458 = F.relu(t_457)
|
||||
t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
|
||||
t_459 = self.n_Conv_44(t_458_padded)
|
||||
t_460 = F.relu(t_459)
|
||||
t_461 = self.n_Conv_45(t_460)
|
||||
t_462 = torch.add(t_461, t_456)
|
||||
t_463 = F.relu(t_462)
|
||||
t_464 = self.n_Conv_46(t_463)
|
||||
t_465 = F.relu(t_464)
|
||||
t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
|
||||
t_466 = self.n_Conv_47(t_465_padded)
|
||||
t_467 = F.relu(t_466)
|
||||
t_468 = self.n_Conv_48(t_467)
|
||||
t_469 = torch.add(t_468, t_463)
|
||||
t_470 = F.relu(t_469)
|
||||
t_471 = self.n_Conv_49(t_470)
|
||||
t_472 = F.relu(t_471)
|
||||
t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
|
||||
t_473 = self.n_Conv_50(t_472_padded)
|
||||
t_474 = F.relu(t_473)
|
||||
t_475 = self.n_Conv_51(t_474)
|
||||
t_476 = torch.add(t_475, t_470)
|
||||
t_477 = F.relu(t_476)
|
||||
t_478 = self.n_Conv_52(t_477)
|
||||
t_479 = F.relu(t_478)
|
||||
t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
|
||||
t_480 = self.n_Conv_53(t_479_padded)
|
||||
t_481 = F.relu(t_480)
|
||||
t_482 = self.n_Conv_54(t_481)
|
||||
t_483 = torch.add(t_482, t_477)
|
||||
t_484 = F.relu(t_483)
|
||||
t_485 = self.n_Conv_55(t_484)
|
||||
t_486 = F.relu(t_485)
|
||||
t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
|
||||
t_487 = self.n_Conv_56(t_486_padded)
|
||||
t_488 = F.relu(t_487)
|
||||
t_489 = self.n_Conv_57(t_488)
|
||||
t_490 = torch.add(t_489, t_484)
|
||||
t_491 = F.relu(t_490)
|
||||
t_492 = self.n_Conv_58(t_491)
|
||||
t_493 = F.relu(t_492)
|
||||
t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
|
||||
t_494 = self.n_Conv_59(t_493_padded)
|
||||
t_495 = F.relu(t_494)
|
||||
t_496 = self.n_Conv_60(t_495)
|
||||
t_497 = torch.add(t_496, t_491)
|
||||
t_498 = F.relu(t_497)
|
||||
t_499 = self.n_Conv_61(t_498)
|
||||
t_500 = F.relu(t_499)
|
||||
t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
|
||||
t_501 = self.n_Conv_62(t_500_padded)
|
||||
t_502 = F.relu(t_501)
|
||||
t_503 = self.n_Conv_63(t_502)
|
||||
t_504 = torch.add(t_503, t_498)
|
||||
t_505 = F.relu(t_504)
|
||||
t_506 = self.n_Conv_64(t_505)
|
||||
t_507 = F.relu(t_506)
|
||||
t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
|
||||
t_508 = self.n_Conv_65(t_507_padded)
|
||||
t_509 = F.relu(t_508)
|
||||
t_510 = self.n_Conv_66(t_509)
|
||||
t_511 = torch.add(t_510, t_505)
|
||||
t_512 = F.relu(t_511)
|
||||
t_513 = self.n_Conv_67(t_512)
|
||||
t_514 = F.relu(t_513)
|
||||
t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
|
||||
t_515 = self.n_Conv_68(t_514_padded)
|
||||
t_516 = F.relu(t_515)
|
||||
t_517 = self.n_Conv_69(t_516)
|
||||
t_518 = torch.add(t_517, t_512)
|
||||
t_519 = F.relu(t_518)
|
||||
t_520 = self.n_Conv_70(t_519)
|
||||
t_521 = F.relu(t_520)
|
||||
t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
|
||||
t_522 = self.n_Conv_71(t_521_padded)
|
||||
t_523 = F.relu(t_522)
|
||||
t_524 = self.n_Conv_72(t_523)
|
||||
t_525 = torch.add(t_524, t_519)
|
||||
t_526 = F.relu(t_525)
|
||||
t_527 = self.n_Conv_73(t_526)
|
||||
t_528 = F.relu(t_527)
|
||||
t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
|
||||
t_529 = self.n_Conv_74(t_528_padded)
|
||||
t_530 = F.relu(t_529)
|
||||
t_531 = self.n_Conv_75(t_530)
|
||||
t_532 = torch.add(t_531, t_526)
|
||||
t_533 = F.relu(t_532)
|
||||
t_534 = self.n_Conv_76(t_533)
|
||||
t_535 = F.relu(t_534)
|
||||
t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
|
||||
t_536 = self.n_Conv_77(t_535_padded)
|
||||
t_537 = F.relu(t_536)
|
||||
t_538 = self.n_Conv_78(t_537)
|
||||
t_539 = torch.add(t_538, t_533)
|
||||
t_540 = F.relu(t_539)
|
||||
t_541 = self.n_Conv_79(t_540)
|
||||
t_542 = F.relu(t_541)
|
||||
t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
|
||||
t_543 = self.n_Conv_80(t_542_padded)
|
||||
t_544 = F.relu(t_543)
|
||||
t_545 = self.n_Conv_81(t_544)
|
||||
t_546 = torch.add(t_545, t_540)
|
||||
t_547 = F.relu(t_546)
|
||||
t_548 = self.n_Conv_82(t_547)
|
||||
t_549 = F.relu(t_548)
|
||||
t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
|
||||
t_550 = self.n_Conv_83(t_549_padded)
|
||||
t_551 = F.relu(t_550)
|
||||
t_552 = self.n_Conv_84(t_551)
|
||||
t_553 = torch.add(t_552, t_547)
|
||||
t_554 = F.relu(t_553)
|
||||
t_555 = self.n_Conv_85(t_554)
|
||||
t_556 = F.relu(t_555)
|
||||
t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
|
||||
t_557 = self.n_Conv_86(t_556_padded)
|
||||
t_558 = F.relu(t_557)
|
||||
t_559 = self.n_Conv_87(t_558)
|
||||
t_560 = torch.add(t_559, t_554)
|
||||
t_561 = F.relu(t_560)
|
||||
t_562 = self.n_Conv_88(t_561)
|
||||
t_563 = F.relu(t_562)
|
||||
t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
|
||||
t_564 = self.n_Conv_89(t_563_padded)
|
||||
t_565 = F.relu(t_564)
|
||||
t_566 = self.n_Conv_90(t_565)
|
||||
t_567 = torch.add(t_566, t_561)
|
||||
t_568 = F.relu(t_567)
|
||||
t_569 = self.n_Conv_91(t_568)
|
||||
t_570 = F.relu(t_569)
|
||||
t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
|
||||
t_571 = self.n_Conv_92(t_570_padded)
|
||||
t_572 = F.relu(t_571)
|
||||
t_573 = self.n_Conv_93(t_572)
|
||||
t_574 = torch.add(t_573, t_568)
|
||||
t_575 = F.relu(t_574)
|
||||
t_576 = self.n_Conv_94(t_575)
|
||||
t_577 = F.relu(t_576)
|
||||
t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
|
||||
t_578 = self.n_Conv_95(t_577_padded)
|
||||
t_579 = F.relu(t_578)
|
||||
t_580 = self.n_Conv_96(t_579)
|
||||
t_581 = torch.add(t_580, t_575)
|
||||
t_582 = F.relu(t_581)
|
||||
t_583 = self.n_Conv_97(t_582)
|
||||
t_584 = F.relu(t_583)
|
||||
t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
|
||||
t_585 = self.n_Conv_98(t_584_padded)
|
||||
t_586 = F.relu(t_585)
|
||||
t_587 = self.n_Conv_99(t_586)
|
||||
t_588 = self.n_Conv_100(t_582)
|
||||
t_589 = torch.add(t_587, t_588)
|
||||
t_590 = F.relu(t_589)
|
||||
t_591 = self.n_Conv_101(t_590)
|
||||
t_592 = F.relu(t_591)
|
||||
t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
|
||||
t_593 = self.n_Conv_102(t_592_padded)
|
||||
t_594 = F.relu(t_593)
|
||||
t_595 = self.n_Conv_103(t_594)
|
||||
t_596 = torch.add(t_595, t_590)
|
||||
t_597 = F.relu(t_596)
|
||||
t_598 = self.n_Conv_104(t_597)
|
||||
t_599 = F.relu(t_598)
|
||||
t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
|
||||
t_600 = self.n_Conv_105(t_599_padded)
|
||||
t_601 = F.relu(t_600)
|
||||
t_602 = self.n_Conv_106(t_601)
|
||||
t_603 = torch.add(t_602, t_597)
|
||||
t_604 = F.relu(t_603)
|
||||
t_605 = self.n_Conv_107(t_604)
|
||||
t_606 = F.relu(t_605)
|
||||
t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
|
||||
t_607 = self.n_Conv_108(t_606_padded)
|
||||
t_608 = F.relu(t_607)
|
||||
t_609 = self.n_Conv_109(t_608)
|
||||
t_610 = torch.add(t_609, t_604)
|
||||
t_611 = F.relu(t_610)
|
||||
t_612 = self.n_Conv_110(t_611)
|
||||
t_613 = F.relu(t_612)
|
||||
t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
|
||||
t_614 = self.n_Conv_111(t_613_padded)
|
||||
t_615 = F.relu(t_614)
|
||||
t_616 = self.n_Conv_112(t_615)
|
||||
t_617 = torch.add(t_616, t_611)
|
||||
t_618 = F.relu(t_617)
|
||||
t_619 = self.n_Conv_113(t_618)
|
||||
t_620 = F.relu(t_619)
|
||||
t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
|
||||
t_621 = self.n_Conv_114(t_620_padded)
|
||||
t_622 = F.relu(t_621)
|
||||
t_623 = self.n_Conv_115(t_622)
|
||||
t_624 = torch.add(t_623, t_618)
|
||||
t_625 = F.relu(t_624)
|
||||
t_626 = self.n_Conv_116(t_625)
|
||||
t_627 = F.relu(t_626)
|
||||
t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
|
||||
t_628 = self.n_Conv_117(t_627_padded)
|
||||
t_629 = F.relu(t_628)
|
||||
t_630 = self.n_Conv_118(t_629)
|
||||
t_631 = torch.add(t_630, t_625)
|
||||
t_632 = F.relu(t_631)
|
||||
t_633 = self.n_Conv_119(t_632)
|
||||
t_634 = F.relu(t_633)
|
||||
t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
|
||||
t_635 = self.n_Conv_120(t_634_padded)
|
||||
t_636 = F.relu(t_635)
|
||||
t_637 = self.n_Conv_121(t_636)
|
||||
t_638 = torch.add(t_637, t_632)
|
||||
t_639 = F.relu(t_638)
|
||||
t_640 = self.n_Conv_122(t_639)
|
||||
t_641 = F.relu(t_640)
|
||||
t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
|
||||
t_642 = self.n_Conv_123(t_641_padded)
|
||||
t_643 = F.relu(t_642)
|
||||
t_644 = self.n_Conv_124(t_643)
|
||||
t_645 = torch.add(t_644, t_639)
|
||||
t_646 = F.relu(t_645)
|
||||
t_647 = self.n_Conv_125(t_646)
|
||||
t_648 = F.relu(t_647)
|
||||
t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
|
||||
t_649 = self.n_Conv_126(t_648_padded)
|
||||
t_650 = F.relu(t_649)
|
||||
t_651 = self.n_Conv_127(t_650)
|
||||
t_652 = torch.add(t_651, t_646)
|
||||
t_653 = F.relu(t_652)
|
||||
t_654 = self.n_Conv_128(t_653)
|
||||
t_655 = F.relu(t_654)
|
||||
t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
|
||||
t_656 = self.n_Conv_129(t_655_padded)
|
||||
t_657 = F.relu(t_656)
|
||||
t_658 = self.n_Conv_130(t_657)
|
||||
t_659 = torch.add(t_658, t_653)
|
||||
t_660 = F.relu(t_659)
|
||||
t_661 = self.n_Conv_131(t_660)
|
||||
t_662 = F.relu(t_661)
|
||||
t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
|
||||
t_663 = self.n_Conv_132(t_662_padded)
|
||||
t_664 = F.relu(t_663)
|
||||
t_665 = self.n_Conv_133(t_664)
|
||||
t_666 = torch.add(t_665, t_660)
|
||||
t_667 = F.relu(t_666)
|
||||
t_668 = self.n_Conv_134(t_667)
|
||||
t_669 = F.relu(t_668)
|
||||
t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
|
||||
t_670 = self.n_Conv_135(t_669_padded)
|
||||
t_671 = F.relu(t_670)
|
||||
t_672 = self.n_Conv_136(t_671)
|
||||
t_673 = torch.add(t_672, t_667)
|
||||
t_674 = F.relu(t_673)
|
||||
t_675 = self.n_Conv_137(t_674)
|
||||
t_676 = F.relu(t_675)
|
||||
t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
|
||||
t_677 = self.n_Conv_138(t_676_padded)
|
||||
t_678 = F.relu(t_677)
|
||||
t_679 = self.n_Conv_139(t_678)
|
||||
t_680 = torch.add(t_679, t_674)
|
||||
t_681 = F.relu(t_680)
|
||||
t_682 = self.n_Conv_140(t_681)
|
||||
t_683 = F.relu(t_682)
|
||||
t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
|
||||
t_684 = self.n_Conv_141(t_683_padded)
|
||||
t_685 = F.relu(t_684)
|
||||
t_686 = self.n_Conv_142(t_685)
|
||||
t_687 = torch.add(t_686, t_681)
|
||||
t_688 = F.relu(t_687)
|
||||
t_689 = self.n_Conv_143(t_688)
|
||||
t_690 = F.relu(t_689)
|
||||
t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
|
||||
t_691 = self.n_Conv_144(t_690_padded)
|
||||
t_692 = F.relu(t_691)
|
||||
t_693 = self.n_Conv_145(t_692)
|
||||
t_694 = torch.add(t_693, t_688)
|
||||
t_695 = F.relu(t_694)
|
||||
t_696 = self.n_Conv_146(t_695)
|
||||
t_697 = F.relu(t_696)
|
||||
t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
|
||||
t_698 = self.n_Conv_147(t_697_padded)
|
||||
t_699 = F.relu(t_698)
|
||||
t_700 = self.n_Conv_148(t_699)
|
||||
t_701 = torch.add(t_700, t_695)
|
||||
t_702 = F.relu(t_701)
|
||||
t_703 = self.n_Conv_149(t_702)
|
||||
t_704 = F.relu(t_703)
|
||||
t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
|
||||
t_705 = self.n_Conv_150(t_704_padded)
|
||||
t_706 = F.relu(t_705)
|
||||
t_707 = self.n_Conv_151(t_706)
|
||||
t_708 = torch.add(t_707, t_702)
|
||||
t_709 = F.relu(t_708)
|
||||
t_710 = self.n_Conv_152(t_709)
|
||||
t_711 = F.relu(t_710)
|
||||
t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
|
||||
t_712 = self.n_Conv_153(t_711_padded)
|
||||
t_713 = F.relu(t_712)
|
||||
t_714 = self.n_Conv_154(t_713)
|
||||
t_715 = torch.add(t_714, t_709)
|
||||
t_716 = F.relu(t_715)
|
||||
t_717 = self.n_Conv_155(t_716)
|
||||
t_718 = F.relu(t_717)
|
||||
t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
|
||||
t_719 = self.n_Conv_156(t_718_padded)
|
||||
t_720 = F.relu(t_719)
|
||||
t_721 = self.n_Conv_157(t_720)
|
||||
t_722 = torch.add(t_721, t_716)
|
||||
t_723 = F.relu(t_722)
|
||||
t_724 = self.n_Conv_158(t_723)
|
||||
t_725 = self.n_Conv_159(t_723)
|
||||
t_726 = F.relu(t_725)
|
||||
t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
|
||||
t_727 = self.n_Conv_160(t_726_padded)
|
||||
t_728 = F.relu(t_727)
|
||||
t_729 = self.n_Conv_161(t_728)
|
||||
t_730 = torch.add(t_729, t_724)
|
||||
t_731 = F.relu(t_730)
|
||||
t_732 = self.n_Conv_162(t_731)
|
||||
t_733 = F.relu(t_732)
|
||||
t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
|
||||
t_734 = self.n_Conv_163(t_733_padded)
|
||||
t_735 = F.relu(t_734)
|
||||
t_736 = self.n_Conv_164(t_735)
|
||||
t_737 = torch.add(t_736, t_731)
|
||||
t_738 = F.relu(t_737)
|
||||
t_739 = self.n_Conv_165(t_738)
|
||||
t_740 = F.relu(t_739)
|
||||
t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
|
||||
t_741 = self.n_Conv_166(t_740_padded)
|
||||
t_742 = F.relu(t_741)
|
||||
t_743 = self.n_Conv_167(t_742)
|
||||
t_744 = torch.add(t_743, t_738)
|
||||
t_745 = F.relu(t_744)
|
||||
t_746 = self.n_Conv_168(t_745)
|
||||
t_747 = self.n_Conv_169(t_745)
|
||||
t_748 = F.relu(t_747)
|
||||
t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
|
||||
t_749 = self.n_Conv_170(t_748_padded)
|
||||
t_750 = F.relu(t_749)
|
||||
t_751 = self.n_Conv_171(t_750)
|
||||
t_752 = torch.add(t_751, t_746)
|
||||
t_753 = F.relu(t_752)
|
||||
t_754 = self.n_Conv_172(t_753)
|
||||
t_755 = F.relu(t_754)
|
||||
t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
|
||||
t_756 = self.n_Conv_173(t_755_padded)
|
||||
t_757 = F.relu(t_756)
|
||||
t_758 = self.n_Conv_174(t_757)
|
||||
t_759 = torch.add(t_758, t_753)
|
||||
t_760 = F.relu(t_759)
|
||||
t_761 = self.n_Conv_175(t_760)
|
||||
t_762 = F.relu(t_761)
|
||||
t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
|
||||
t_763 = self.n_Conv_176(t_762_padded)
|
||||
t_764 = F.relu(t_763)
|
||||
t_765 = self.n_Conv_177(t_764)
|
||||
t_766 = torch.add(t_765, t_760)
|
||||
t_767 = F.relu(t_766)
|
||||
t_768 = self.n_Conv_178(t_767)
|
||||
t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
|
||||
t_770 = torch.squeeze(t_769, 3)
|
||||
t_770 = torch.squeeze(t_770, 2)
|
||||
t_771 = torch.sigmoid(t_770)
|
||||
return t_771
|
||||
|
||||
def load_state_dict(self, state_dict, **kwargs):
|
||||
self.tags = state_dict.get('tags', [])
|
||||
|
||||
super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
|
||||
|
||||
@ -1,60 +1,140 @@
|
||||
import sys, os, shlex
|
||||
import contextlib
|
||||
import torch
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
||||
from modules import errors
|
||||
from packaging import version
|
||||
|
||||
has_mps = getattr(torch, 'has_mps', False)
|
||||
|
||||
cpu = torch.device("cpu")
|
||||
# has_mps is only available in nightly pytorch (for now) and requires macOS 12.3+.
# check via `getattr`, then try allocating on the device, for compatibility
|
||||
def has_mps() -> bool:
|
||||
if not getattr(torch, 'has_mps', False):
|
||||
return False
|
||||
try:
|
||||
torch.zeros(1).to(torch.device("mps"))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def extract_device_id(args, name):
|
||||
for x in range(len(args)):
|
||||
if name in args[x]:
|
||||
return args[x + 1]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_cuda_device_string():
|
||||
from modules import shared
|
||||
|
||||
if shared.cmd_opts.device_id is not None:
|
||||
return f"cuda:{shared.cmd_opts.device_id}"
|
||||
|
||||
return "cuda"
|
||||
|
||||
|
||||
def get_optimal_device():
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
return torch.device(get_cuda_device_string())
|
||||
|
||||
if has_mps:
|
||||
if has_mps():
|
||||
return torch.device("mps")
|
||||
|
||||
return cpu
|
||||
|
||||
|
||||
def get_device_for(task):
|
||||
from modules import shared
|
||||
|
||||
if task in shared.cmd_opts.use_cpu:
|
||||
return cpu
|
||||
|
||||
return get_optimal_device()
|
||||
|
||||
|
||||
def torch_gc():
|
||||
if torch.cuda.is_available():
|
||||
with torch.cuda.device(get_cuda_device_string()):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
def enable_tf32():
|
||||
if torch.cuda.is_available():
|
||||
|
||||
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
||||
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
||||
if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
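For reference, a small sketch of the capability check used above (assumes a CUDA build of torch); compute capability (7, 5) corresponds to Turing cards such as the GTX 16xx series:

import torch

if torch.cuda.is_available():
    # (major, minor) compute capability, e.g. (7, 5) on a GTX 1660.
    print(torch.cuda.get_device_capability(0))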
|
||||
|
||||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
device = get_optimal_device()
|
||||
device_codeformer = cpu if has_mps else device
|
||||
cpu = torch.device("cpu")
|
||||
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
|
||||
dtype = torch.float16
|
||||
dtype_vae = torch.float16
|
||||
|
||||
|
||||
def randn(seed, shape):
|
||||
# PyTorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
||||
if device.type == 'mps':
|
||||
generator = torch.Generator(device=cpu)
|
||||
generator.manual_seed(seed)
|
||||
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
|
||||
return noise
|
||||
|
||||
torch.manual_seed(seed)
|
||||
if device.type == 'mps':
|
||||
return torch.randn(shape, device=cpu).to(device)
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def randn_without_seed(shape):
|
||||
# PyTorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
||||
if device.type == 'mps':
|
||||
generator = torch.Generator(device=cpu)
|
||||
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
|
||||
return noise
|
||||
|
||||
return torch.randn(shape, device=cpu).to(device)
|
||||
return torch.randn(shape, device=device)
|
||||
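A compact sketch of the CPU-generator workaround used by both functions above (assumes torch is installed; the shape is arbitrary):

import torch

generator = torch.Generator(device="cpu")
generator.manual_seed(1234)
noise = torch.randn((1, 4, 64, 64), generator=generator, device="cpu")
# In the functions above, this tensor is then moved to the target device with .to(device).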
|
||||
|
||||
def autocast(disable=False):
|
||||
from modules import shared
|
||||
|
||||
if disable:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
||||
return contextlib.nullcontext()
|
||||
|
||||
return torch.autocast("cuda")
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||
orig_tensor_to = torch.Tensor.to
|
||||
def tensor_to_fix(self, *args, **kwargs):
|
||||
if self.device.type != 'mps' and \
|
||||
((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
|
||||
(isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
|
||||
self = self.contiguous()
|
||||
return orig_tensor_to(self, *args, **kwargs)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||
orig_layer_norm = torch.nn.functional.layer_norm
|
||||
def layer_norm_fix(*args, **kwargs):
|
||||
if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
|
||||
args = list(args)
|
||||
args[0] = args[0].contiguous()
|
||||
return orig_layer_norm(*args, **kwargs)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
||||
orig_tensor_numpy = torch.Tensor.numpy
|
||||
def numpy_fix(self, *args, **kwargs):
|
||||
if self.requires_grad:
|
||||
self = self.detach()
|
||||
return orig_tensor_numpy(self, *args, **kwargs)
|
||||
|
||||
|
||||
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
||||
if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
|
||||
torch.Tensor.to = tensor_to_fix
|
||||
torch.nn.functional.layer_norm = layer_norm_fix
|
||||
torch.Tensor.numpy = numpy_fix
|
||||
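A quick sketch of how the version gate above evaluates (assumes the packaging library is installed):

from packaging import version

assert version.parse("1.12.1") < version.parse("1.13")        # workarounds applied
assert not (version.parse("1.13.0") < version.parse("1.13"))  # 1.13 and newer: left as-is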
|
||||
@ -1,80 +0,0 @@
|
||||
# this file is taken from https://github.com/xinntao/ESRGAN
|
||||
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
# initialization
|
||||
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
'''Residual in Residual Dense Block'''
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
|
||||
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
return out
|
||||
@ -0,0 +1,463 @@
|
||||
# this file is adapted from https://github.com/victorca25/iNNfer
|
||||
|
||||
import math
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
####################
|
||||
# RRDBNet Generator
|
||||
####################
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
|
||||
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
|
||||
finalact=None, gaussian_noise=False, plus=False):
|
||||
super(RRDBNet, self).__init__()
|
||||
n_upscale = int(math.log(upscale, 2))
|
||||
if upscale == 3:
|
||||
n_upscale = 1
|
||||
|
||||
self.resrgan_scale = 0
|
||||
if in_nc % 16 == 0:
|
||||
self.resrgan_scale = 1
|
||||
elif in_nc != 4 and in_nc % 4 == 0:
|
||||
self.resrgan_scale = 2
|
||||
|
||||
fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
||||
rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
|
||||
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
|
||||
LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
|
||||
|
||||
if upsample_mode == 'upconv':
|
||||
upsample_block = upconv_block
|
||||
elif upsample_mode == 'pixelshuffle':
|
||||
upsample_block = pixelshuffle_block
|
||||
else:
|
||||
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
|
||||
if upscale == 3:
|
||||
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
|
||||
else:
|
||||
upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
|
||||
HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
|
||||
HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
||||
|
||||
outact = act(finalact) if finalact else None
|
||||
|
||||
self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
|
||||
*upsampler, HR_conv0, HR_conv1, outact)
|
||||
|
||||
def forward(self, x, outm=None):
|
||||
if self.resrgan_scale == 1:
|
||||
feat = pixel_unshuffle(x, scale=4)
|
||||
elif self.resrgan_scale == 2:
|
||||
feat = pixel_unshuffle(x, scale=2)
|
||||
else:
|
||||
feat = x
|
||||
|
||||
return self.model(feat)
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
"""
|
||||
Residual in Residual Dense Block
|
||||
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
|
||||
"""
|
||||
|
||||
def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False, gaussian_noise=False, plus=False):
|
||||
super(RRDB, self).__init__()
|
||||
# This is for backwards compatibility with existing models
|
||||
if nr == 3:
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
else:
|
||||
RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
|
||||
self.RDBs = nn.Sequential(*RDB_list)
|
||||
|
||||
def forward(self, x):
|
||||
if hasattr(self, 'RDB1'):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
else:
|
||||
out = self.RDBs(x)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
"""
|
||||
Residual Dense Block
|
||||
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
|
||||
Modified options that can be used:
|
||||
- "Partial Convolution based Padding" arXiv:1811.11718
|
||||
- "Spectral normalization" arXiv:1802.05957
|
||||
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
||||
{Rakotonirina} and A. {Rasoanaivo}
|
||||
"""
|
||||
|
||||
def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False, gaussian_noise=False, plus=False):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
|
||||
self.noise = GaussianNoise() if gaussian_noise else None
|
||||
self.conv1x1 = conv1x1(nf, gc) if plus else None
|
||||
|
||||
self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
if mode == 'CNA':
|
||||
last_act = None
|
||||
else:
|
||||
last_act = act_type
|
||||
self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.conv1(x)
|
||||
x2 = self.conv2(torch.cat((x, x1), 1))
|
||||
if self.conv1x1:
|
||||
x2 = x2 + self.conv1x1(x)
|
||||
x3 = self.conv3(torch.cat((x, x1, x2), 1))
|
||||
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
|
||||
if self.conv1x1:
|
||||
x4 = x4 + x2
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
if self.noise:
|
||||
return self.noise(x5.mul(0.2) + x)
|
||||
else:
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
####################
|
||||
# ESRGANplus
|
||||
####################
|
||||
|
||||
class GaussianNoise(nn.Module):
|
||||
def __init__(self, sigma=0.1, is_relative_detach=False):
|
||||
super().__init__()
|
||||
self.sigma = sigma
|
||||
self.is_relative_detach = is_relative_detach
|
||||
self.noise = torch.tensor(0, dtype=torch.float)
|
||||
|
||||
def forward(self, x):
|
||||
if self.training and self.sigma != 0:
|
||||
self.noise = self.noise.to(x.device)
|
||||
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
||||
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
||||
x = x + sampled_noise
|
||||
return x
|
||||
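# Note on GaussianNoise above: noise is only injected while the module is in training mode
# and sigma is non-zero; at inference time forward() is a no-op. With is_relative_detach=True
# the noise scale tracks the input magnitude but is excluded from the gradient (x.detach()).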
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
####################
|
||||
# SRVGGNetCompact
|
||||
####################
|
||||
|
||||
class SRVGGNetCompact(nn.Module):
|
||||
"""A compact VGG-style network structure for super-resolution.
|
||||
This class is copied from https://github.com/xinntao/Real-ESRGAN
|
||||
"""
|
||||
|
||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
|
||||
super(SRVGGNetCompact, self).__init__()
|
||||
self.num_in_ch = num_in_ch
|
||||
self.num_out_ch = num_out_ch
|
||||
self.num_feat = num_feat
|
||||
self.num_conv = num_conv
|
||||
self.upscale = upscale
|
||||
self.act_type = act_type
|
||||
|
||||
self.body = nn.ModuleList()
|
||||
# the first conv
|
||||
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
|
||||
# the first activation
|
||||
if act_type == 'relu':
|
||||
activation = nn.ReLU(inplace=True)
|
||||
elif act_type == 'prelu':
|
||||
activation = nn.PReLU(num_parameters=num_feat)
|
||||
elif act_type == 'leakyrelu':
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.body.append(activation)
|
||||
|
||||
# the body structure
|
||||
for _ in range(num_conv):
|
||||
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
|
||||
# activation
|
||||
if act_type == 'relu':
|
||||
activation = nn.ReLU(inplace=True)
|
||||
elif act_type == 'prelu':
|
||||
activation = nn.PReLU(num_parameters=num_feat)
|
||||
elif act_type == 'leakyrelu':
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.body.append(activation)
|
||||
|
||||
# the last conv
|
||||
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
|
||||
# upsample
|
||||
self.upsampler = nn.PixelShuffle(upscale)
|
||||
|
||||
def forward(self, x):
|
||||
out = x
|
||||
for i in range(0, len(self.body)):
|
||||
out = self.body[i](out)
|
||||
|
||||
out = self.upsampler(out)
|
||||
# add the nearest upsampled image, so that the network learns the residual
|
||||
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
|
||||
out += base
|
||||
return out
|
||||
|
||||
|
||||
####################
|
||||
# Upsampler
|
||||
####################
|
||||
|
||||
class Upsample(nn.Module):
|
||||
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
|
||||
The input data is assumed to be of the form
|
||||
`minibatch x channels x [optional depth] x [optional height] x width`.
|
||||
"""
|
||||
|
||||
def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
super(Upsample, self).__init__()
|
||||
if isinstance(scale_factor, tuple):
|
||||
self.scale_factor = tuple(float(factor) for factor in scale_factor)
|
||||
else:
|
||||
self.scale_factor = float(scale_factor) if scale_factor else None
|
||||
self.mode = mode
|
||||
self.size = size
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
|
||||
|
||||
def extra_repr(self):
|
||||
if self.scale_factor is not None:
|
||||
info = 'scale_factor=' + str(self.scale_factor)
|
||||
else:
|
||||
info = 'size=' + str(self.size)
|
||||
info += ', mode=' + self.mode
|
||||
return info
|
||||
|
||||
|
||||
def pixel_unshuffle(x, scale):
|
||||
""" Pixel unshuffle.
|
||||
Args:
|
||||
x (Tensor): Input feature with shape (b, c, hh, hw).
|
||||
scale (int): Downsample ratio.
|
||||
Returns:
|
||||
Tensor: the pixel unshuffled feature.
|
||||
"""
|
||||
b, c, hh, hw = x.size()
|
||||
out_channel = c * (scale**2)
|
||||
assert hh % scale == 0 and hw % scale == 0
|
||||
h = hh // scale
|
||||
w = hw // scale
|
||||
x_view = x.view(b, c, h, scale, w, scale)
|
||||
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
|
||||
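# Minimal sanity sketch for pixel_unshuffle (tensor sizes below are illustrative only):
# spatial resolution is traded for channels, (b, c, H, W) -> (b, c*scale**2, H/scale, W/scale),
# i.e. the inverse of nn.PixelShuffle.
#   x = torch.zeros(1, 3, 8, 8)
#   assert pixel_unshuffle(x, scale=2).shape == (1, 12, 4, 4)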
|
||||
|
||||
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
|
||||
"""
|
||||
Pixel shuffle layer
|
||||
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
|
||||
Neural Network, CVPR17)
|
||||
"""
|
||||
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
|
||||
pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
|
||||
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
||||
|
||||
n = norm(norm_type, out_nc) if norm_type else None
|
||||
a = act(act_type) if act_type else None
|
||||
return sequential(conv, pixel_shuffle, n, a)
|
||||
|
||||
|
||||
def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
|
||||
""" Upconv layer """
|
||||
upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
|
||||
upsample = Upsample(scale_factor=upscale_factor, mode=mode)
|
||||
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
|
||||
pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
|
||||
return sequential(upsample, conv)
|
||||
####################
|
||||
# Basic blocks
|
||||
####################
|
||||
|
||||
|
||||
def make_layer(basic_block, num_basic_block, **kwarg):
|
||||
"""Make layers by stacking the same blocks.
|
||||
Args:
|
||||
basic_block (nn.module): nn.module class for basic block. (block)
|
||||
num_basic_block (int): number of blocks. (n_layers)
|
||||
Returns:
|
||||
nn.Sequential: Stacked blocks in nn.Sequential.
|
||||
"""
|
||||
layers = []
|
||||
for _ in range(num_basic_block):
|
||||
layers.append(basic_block(**kwarg))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
|
||||
""" activation helper """
|
||||
act_type = act_type.lower()
|
||||
if act_type == 'relu':
|
||||
layer = nn.ReLU(inplace)
|
||||
elif act_type in ('leakyrelu', 'lrelu'):
|
||||
layer = nn.LeakyReLU(neg_slope, inplace)
|
||||
elif act_type == 'prelu':
|
||||
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
|
||||
elif act_type == 'tanh': # [-1, 1] range output
|
||||
layer = nn.Tanh()
|
||||
elif act_type == 'sigmoid': # [0, 1] range output
|
||||
layer = nn.Sigmoid()
|
||||
else:
|
||||
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
|
||||
return layer
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self, *kwargs):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x, *kwargs):
|
||||
return x
|
||||
|
||||
|
||||
def norm(norm_type, nc):
|
||||
""" Return a normalization layer """
|
||||
norm_type = norm_type.lower()
|
||||
if norm_type == 'batch':
|
||||
layer = nn.BatchNorm2d(nc, affine=True)
|
||||
elif norm_type == 'instance':
|
||||
layer = nn.InstanceNorm2d(nc, affine=False)
|
||||
elif norm_type == 'none':
|
||||
layer = Identity()  # 'none' still needs to assign layer, otherwise the return below raises UnboundLocalError
|
||||
else:
|
||||
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
|
||||
return layer
|
||||
|
||||
|
||||
def pad(pad_type, padding):
|
||||
""" padding layer helper """
|
||||
pad_type = pad_type.lower()
|
||||
if padding == 0:
|
||||
return None
|
||||
if pad_type == 'reflect':
|
||||
layer = nn.ReflectionPad2d(padding)
|
||||
elif pad_type == 'replicate':
|
||||
layer = nn.ReplicationPad2d(padding)
|
||||
elif pad_type == 'zero':
|
||||
layer = nn.ZeroPad2d(padding)
|
||||
else:
|
||||
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
|
||||
return layer
|
||||
|
||||
|
||||
def get_valid_padding(kernel_size, dilation):
|
||||
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
|
||||
padding = (kernel_size - 1) // 2
|
||||
return padding
|
||||
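# Worked example for get_valid_padding (illustrative): a 3x3 kernel with dilation 1 keeps an
# effective size of 3, so the padding is 1; with dilation 2 the effective size is 5 and the
# padding is 2. Either way a stride-1 convolution preserves the spatial dimensions.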
|
||||
|
||||
class ShortcutBlock(nn.Module):
|
||||
""" Elementwise sum the output of a submodule to its input """
|
||||
def __init__(self, submodule):
|
||||
super(ShortcutBlock, self).__init__()
|
||||
self.sub = submodule
|
||||
|
||||
def forward(self, x):
|
||||
output = x + self.sub(x)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
|
||||
|
||||
|
||||
def sequential(*args):
|
||||
""" Flatten Sequential. It unwraps nn.Sequential. """
|
||||
if len(args) == 1:
|
||||
if isinstance(args[0], OrderedDict):
|
||||
raise NotImplementedError('sequential does not support OrderedDict input.')
|
||||
return args[0] # No sequential is needed.
|
||||
modules = []
|
||||
for module in args:
|
||||
if isinstance(module, nn.Sequential):
|
||||
for submodule in module.children():
|
||||
modules.append(submodule)
|
||||
elif isinstance(module, nn.Module):
|
||||
modules.append(module)
|
||||
return nn.Sequential(*modules)
|
||||
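# Note on sequential() above: arguments that are neither nn.Sequential nor nn.Module (such as
# the None passed for a missing norm or activation layer) are silently dropped, and nested
# nn.Sequential containers are flattened into a single one.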
|
||||
|
||||
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False):
|
||||
""" Conv layer with padding, normalization, activation """
|
||||
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
|
||||
padding = get_valid_padding(kernel_size, dilation)
|
||||
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
||||
padding = padding if pad_type == 'zero' else 0
|
||||
|
||||
if convtype=='PartialConv2D':
|
||||
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
elif convtype=='DeformConv2D':
|
||||
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
elif convtype=='Conv3D':
|
||||
c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
else:
|
||||
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
|
||||
if spectral_norm:
|
||||
c = nn.utils.spectral_norm(c)
|
||||
|
||||
a = act(act_type) if act_type else None
|
||||
if 'CNA' in mode:
|
||||
n = norm(norm_type, out_nc) if norm_type else None
|
||||
return sequential(p, c, n, a)
|
||||
elif mode == 'NAC':
|
||||
if norm_type is None and act_type is not None:
|
||||
a = act(act_type, inplace=False)
|
||||
n = norm(norm_type, in_nc) if norm_type else None
|
||||
return sequential(n, a, p, c)
|
||||
@ -0,0 +1,99 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import git
|
||||
|
||||
from modules import paths, shared
|
||||
|
||||
extensions = []
|
||||
extensions_dir = os.path.join(paths.script_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
|
||||
|
||||
|
||||
def active():
|
||||
return [x for x in extensions if x.enabled]
|
||||
|
||||
|
||||
class Extension:
|
||||
def __init__(self, name, path, enabled=True, is_builtin=False):
|
||||
self.name = name
|
||||
self.path = path
|
||||
self.enabled = enabled
|
||||
self.status = ''
|
||||
self.can_update = False
|
||||
self.is_builtin = is_builtin
|
||||
|
||||
repo = None
|
||||
try:
|
||||
if os.path.exists(os.path.join(path, ".git")):
|
||||
repo = git.Repo(path)
|
||||
except Exception:
|
||||
print(f"Error reading github repository info from {path}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
if repo is None or repo.bare:
|
||||
self.remote = None
|
||||
else:
|
||||
try:
|
||||
self.remote = next(repo.remote().urls, None)
|
||||
self.status = 'unknown'
|
||||
except Exception:
|
||||
self.remote = None
|
||||
|
||||
def list_files(self, subdir, extension):
|
||||
from modules import scripts
|
||||
|
||||
dirpath = os.path.join(self.path, subdir)
|
||||
if not os.path.isdir(dirpath):
|
||||
return []
|
||||
|
||||
res = []
|
||||
for filename in sorted(os.listdir(dirpath)):
|
||||
res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
|
||||
|
||||
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||
|
||||
return res
|
||||
|
||||
def check_updates(self):
|
||||
repo = git.Repo(self.path)
|
||||
for fetch in repo.remote().fetch("--dry-run"):
|
||||
if fetch.flags != fetch.HEAD_UPTODATE:
|
||||
self.can_update = True
|
||||
self.status = "behind"
|
||||
return
|
||||
|
||||
self.can_update = False
|
||||
self.status = "latest"
|
||||
|
||||
def fetch_and_reset_hard(self):
|
||||
repo = git.Repo(self.path)
|
||||
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
||||
# because WSL2 Docker sets 755 file permissions instead of 644, which results in this error.
|
||||
repo.git.fetch('--all')
|
||||
repo.git.reset('--hard', 'origin')
|
||||
|
||||
|
||||
def list_extensions():
|
||||
extensions.clear()
|
||||
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
paths = []
|
||||
for dirname in [extensions_dir, extensions_builtin_dir]:
|
||||
if not os.path.isdir(dirname):
|
||||
continue  # skip a missing directory instead of aborting the scan of the remaining ones
|
||||
|
||||
for extension_dirname in sorted(os.listdir(dirname)):
|
||||
path = os.path.join(dirname, extension_dirname)
|
||||
if not os.path.isdir(path):
|
||||
continue
|
||||
|
||||
paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
|
||||
|
||||
for dirname, path, is_builtin in paths:
|
||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
|
||||
extensions.append(extension)
|
||||
|
||||
@ -0,0 +1,667 @@
|
||||
import csv
|
||||
import datetime
|
||||
import glob
|
||||
import html
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import inspect
|
||||
|
||||
import modules.textual_inversion.dataset
|
||||
import torch
|
||||
import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from ldm.util import default
|
||||
from modules import devices, processing, sd_models, shared, sd_samplers
|
||||
from modules.textual_inversion import textual_inversion
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
from torch import einsum
|
||||
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from statistics import stdev, mean
|
||||
|
||||
|
||||
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
activation_dict = {
|
||||
"linear": torch.nn.Identity,
|
||||
"relu": torch.nn.ReLU,
|
||||
"leakyrelu": torch.nn.LeakyReLU,
|
||||
"elu": torch.nn.ELU,
|
||||
"swish": torch.nn.Hardswish,
|
||||
"tanh": torch.nn.Tanh,
|
||||
"sigmoid": torch.nn.Sigmoid,
|
||||
}
|
||||
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
||||
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
|
||||
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
|
||||
super().__init__()
|
||||
|
||||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||
|
||||
linears = []
|
||||
for i in range(len(layer_structure) - 1):
|
||||
|
||||
# Add a fully-connected layer
|
||||
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||
|
||||
# Add an activation func except last layer
|
||||
if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
|
||||
pass
|
||||
elif activation_func in self.activation_dict:
|
||||
linears.append(self.activation_dict[activation_func]())
|
||||
else:
|
||||
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
|
||||
|
||||
# Add layer normalization
|
||||
if add_layer_norm:
|
||||
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||
|
||||
# Add dropout except last layer
|
||||
if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
|
||||
linears.append(torch.nn.Dropout(p=0.3))
|
||||
|
||||
self.linear = torch.nn.Sequential(*linears)
|
||||
|
||||
if state_dict is not None:
|
||||
self.fix_old_state_dict(state_dict)
|
||||
self.load_state_dict(state_dict)
|
||||
else:
|
||||
for layer in self.linear:
|
||||
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||
w, b = layer.weight.data, layer.bias.data
|
||||
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
|
||||
normal_(w, mean=0.0, std=0.01)
|
||||
normal_(b, mean=0.0, std=0)
|
||||
elif weight_init == 'XavierUniform':
|
||||
xavier_uniform_(w)
|
||||
zeros_(b)
|
||||
elif weight_init == 'XavierNormal':
|
||||
xavier_normal_(w)
|
||||
zeros_(b)
|
||||
elif weight_init == 'KaimingUniform':
|
||||
kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
||||
zeros_(b)
|
||||
elif weight_init == 'KaimingNormal':
|
||||
kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
||||
zeros_(b)
|
||||
else:
|
||||
raise KeyError(f"Key {weight_init} is not defined as initialization!")
|
||||
self.to(devices.device)
|
||||
|
||||
def fix_old_state_dict(self, state_dict):
|
||||
changes = {
|
||||
'linear1.bias': 'linear.0.bias',
|
||||
'linear1.weight': 'linear.0.weight',
|
||||
'linear2.bias': 'linear.1.bias',
|
||||
'linear2.weight': 'linear.1.weight',
|
||||
}
|
||||
|
||||
for fr, to in changes.items():
|
||||
x = state_dict.get(fr, None)
|
||||
if x is None:
|
||||
continue
|
||||
|
||||
del state_dict[fr]
|
||||
state_dict[to] = x
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.linear(x) * self.multiplier
|
||||
|
||||
def trainables(self):
|
||||
layer_structure = []
|
||||
for layer in self.linear:
|
||||
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||
layer_structure += [layer.weight, layer.bias]
|
||||
return layer_structure
|
||||
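# Shape sketch for HypernetworkModule (the numbers are an illustrative example, not defaults):
# with dim=768 and layer_structure=[1, 2, 1], self.linear maps 768 -> 1536 -> 768, and
# forward() returns x + linear(x) * multiplier, i.e. a residual MLP whose contribution is
# scaled by the global hypernetwork strength.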
|
||||
|
||||
def apply_strength(value=None):
|
||||
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
||||
|
||||
|
||||
class Hypernetwork:
|
||||
filename = None
|
||||
name = None
|
||||
|
||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
|
||||
self.filename = None
|
||||
self.name = name
|
||||
self.layers = {}
|
||||
self.step = 0
|
||||
self.sd_checkpoint = None
|
||||
self.sd_checkpoint_name = None
|
||||
self.layer_structure = layer_structure
|
||||
self.activation_func = activation_func
|
||||
self.weight_init = weight_init
|
||||
self.add_layer_norm = add_layer_norm
|
||||
self.use_dropout = use_dropout
|
||||
self.activate_output = activate_output
|
||||
self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
|
||||
self.optimizer_name = None
|
||||
self.optimizer_state_dict = None
|
||||
|
||||
for size in enable_sizes or []:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
)
|
||||
self.eval_mode()
|
||||
|
||||
def weights(self):
|
||||
res = []
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
res += layer.parameters()
|
||||
return res
|
||||
|
||||
def train_mode(self):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.train()
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
def eval_mode(self):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.eval()
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def save(self, filename):
|
||||
state_dict = {}
|
||||
optimizer_saved_dict = {}
|
||||
|
||||
for k, v in self.layers.items():
|
||||
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
|
||||
|
||||
state_dict['step'] = self.step
|
||||
state_dict['name'] = self.name
|
||||
state_dict['layer_structure'] = self.layer_structure
|
||||
state_dict['activation_func'] = self.activation_func
|
||||
state_dict['is_layer_norm'] = self.add_layer_norm
|
||||
state_dict['weight_initialization'] = self.weight_init
|
||||
state_dict['use_dropout'] = self.use_dropout
|
||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||
state_dict['activate_output'] = self.activate_output
|
||||
state_dict['last_layer_dropout'] = self.last_layer_dropout
|
||||
|
||||
if self.optimizer_name is not None:
|
||||
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
|
||||
|
||||
torch.save(state_dict, filename)
|
||||
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
|
||||
optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
|
||||
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
||||
torch.save(optimizer_saved_dict, filename + '.optim')
|
||||
|
||||
def load(self, filename):
|
||||
self.filename = filename
|
||||
if self.name is None:
|
||||
self.name = os.path.splitext(os.path.basename(filename))[0]
|
||||
|
||||
state_dict = torch.load(filename, map_location='cpu')
|
||||
|
||||
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||
print(self.layer_structure)
|
||||
self.activation_func = state_dict.get('activation_func', None)
|
||||
print(f"Activation function is {self.activation_func}")
|
||||
self.weight_init = state_dict.get('weight_initialization', 'Normal')
|
||||
print(f"Weight initialization is {self.weight_init}")
|
||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||
print(f"Layer norm is set to {self.add_layer_norm}")
|
||||
self.use_dropout = state_dict.get('use_dropout', False)
|
||||
print(f"Dropout usage is set to {self.use_dropout}" )
|
||||
self.activate_output = state_dict.get('activate_output', True)
|
||||
print(f"Activate last layer is set to {self.activate_output}")
|
||||
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
|
||||
|
||||
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
|
||||
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
||||
print(f"Optimizer name is {self.optimizer_name}")
|
||||
if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
|
||||
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||
else:
|
||||
self.optimizer_state_dict = None
|
||||
if self.optimizer_state_dict:
|
||||
print("Loaded existing optimizer from checkpoint")
|
||||
else:
|
||||
print("No saved optimizer exists in checkpoint")
|
||||
|
||||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
)
|
||||
|
||||
self.name = state_dict.get('name', self.name)
|
||||
self.step = state_dict.get('step', 0)
|
||||
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
|
||||
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
||||
|
||||
|
||||
def list_hypernetworks(path):
|
||||
res = {}
|
||||
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
|
||||
name = os.path.splitext(os.path.basename(filename))[0]
|
||||
# Prevent a hypothetical "None.pt" from being listed.
|
||||
if name != "None":
|
||||
res[name + f"({sd_models.model_hash(filename)})"] = filename
|
||||
return res
|
||||
|
||||
|
||||
def load_hypernetwork(filename):
|
||||
path = shared.hypernetworks.get(filename, None)
|
||||
# Prevent any file named "None.pt" from being loaded.
|
||||
if path is not None and filename != "None":
|
||||
print(f"Loading hypernetwork {filename}")
|
||||
try:
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
shared.loaded_hypernetwork.load(path)
|
||||
|
||||
except Exception:
|
||||
print(f"Error loading hypernetwork {path}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
else:
|
||||
if shared.loaded_hypernetwork is not None:
|
||||
print("Unloading hypernetwork")
|
||||
|
||||
shared.loaded_hypernetwork = None
|
||||
|
||||
|
||||
def find_closest_hypernetwork_name(search: str):
|
||||
if not search:
|
||||
return None
|
||||
search = search.lower()
|
||||
applicable = [name for name in shared.hypernetworks if search in name.lower()]
|
||||
if not applicable:
|
||||
return None
|
||||
applicable = sorted(applicable, key=lambda name: len(name))
|
||||
return applicable[0]
|
||||
|
||||
|
||||
def apply_hypernetwork(hypernetwork, context, layer=None):
|
||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
||||
|
||||
if hypernetwork_layers is None:
|
||||
return context, context
|
||||
|
||||
if layer is not None:
|
||||
layer.hyper_k = hypernetwork_layers[0]
|
||||
layer.hyper_v = hypernetwork_layers[1]
|
||||
|
||||
context_k = hypernetwork_layers[0](context)
|
||||
context_v = hypernetwork_layers[1](context)
|
||||
return context_k, context_v
|
||||
|
||||
|
||||
def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
|
||||
k = self.to_k(context_k)
|
||||
v = self.to_v(context_v)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
if mask is not None:
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
def stack_conds(conds):
|
||||
if len(conds) == 1:
|
||||
return torch.stack(conds)
|
||||
|
||||
# same as in reconstruct_multicond_batch
|
||||
token_count = max([x.shape[0] for x in conds])
|
||||
for i in range(len(conds)):
|
||||
if conds[i].shape[0] != token_count:
|
||||
last_vector = conds[i][-1:]
|
||||
last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
|
||||
conds[i] = torch.vstack([conds[i], last_vector_repeated])
|
||||
|
||||
return torch.stack(conds)
|
||||
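# Illustrative example for stack_conds (shapes are hypothetical): given conds with shapes
# (77, 768) and (154, 768), the shorter tensor has its last row repeated 77 times so that
# torch.stack() can build a (2, 154, 768) batch.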
|
||||
|
||||
def statistics(data):
|
||||
if len(data) < 2:
|
||||
std = 0
|
||||
else:
|
||||
std = stdev(data)
|
||||
total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std / (len(data) ** 0.5):.3f})"
|
||||
recent_data = data[-32:]
|
||||
if len(recent_data) < 2:
|
||||
std = 0
|
||||
else:
|
||||
std = stdev(recent_data)
|
||||
recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
|
||||
return total_information, recent_information
|
||||
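# statistics() reports the mean loss plus/minus the standard error of the mean (stdev / sqrt(n)),
# once over the full history and once over the last 32 entries. Illustrative example:
#   statistics([0.1, 0.2, 0.3]) -> ("loss:0.200±(0.058)", "recent 32 loss:0.200±(0.058)")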
|
||||
|
||||
def report_statistics(loss_info:dict):
|
||||
keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
|
||||
for key in keys:
|
||||
try:
|
||||
print("Loss statistics for file " + key)
|
||||
info, recent = statistics(list(loss_info[key]))
|
||||
print(info)
|
||||
print(recent)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
||||
# Remove illegal characters from name.
|
||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||
|
||||
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
||||
if not overwrite_old:
|
||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||
|
||||
if type(layer_structure) == str:
|
||||
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
||||
|
||||
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
|
||||
name=name,
|
||||
enable_sizes=[int(x) for x in enable_sizes],
|
||||
layer_structure=layer_structure,
|
||||
activation_func=activation_func,
|
||||
weight_init=weight_init,
|
||||
add_layer_norm=add_layer_norm,
|
||||
use_dropout=use_dropout,
|
||||
)
|
||||
hypernet.save(fn)
|
||||
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
|
||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
||||
|
||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
shared.loaded_hypernetwork.load(path)
|
||||
|
||||
shared.state.job = "train-hypernetwork"
|
||||
shared.state.textinfo = "Initializing hypernetwork training..."
|
||||
shared.state.job_count = steps
|
||||
|
||||
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
|
||||
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
|
||||
unload = shared.opts.unload_models_when_training
|
||||
|
||||
if save_hypernetwork_every > 0:
|
||||
hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
|
||||
os.makedirs(hypernetwork_dir, exist_ok=True)
|
||||
else:
|
||||
hypernetwork_dir = None
|
||||
|
||||
if create_image_every > 0:
|
||||
images_dir = os.path.join(log_directory, "images")
|
||||
os.makedirs(images_dir, exist_ok=True)
|
||||
else:
|
||||
images_dir = None
|
||||
|
||||
hypernetwork = shared.loaded_hypernetwork
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
initial_step = hypernetwork.step or 0
|
||||
if initial_step >= steps:
|
||||
shared.state.textinfo = "Model has already been trained beyond specified max steps"
|
||||
return hypernetwork, filename
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
|
||||
pin_memory = shared.opts.pin_memory
|
||||
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
|
||||
|
||||
latent_sampling_method = ds.latent_sampling_method
|
||||
|
||||
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
|
||||
|
||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||
|
||||
if unload:
|
||||
shared.parallel_processing_allowed = False
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
weights = hypernetwork.weights()
|
||||
hypernetwork.train_mode()
|
||||
|
||||
# Use the optimizer type saved with the hypernetwork; alternatively this could be exposed as a UI option.
|
||||
if hypernetwork.optimizer_name in optimizer_dict:
|
||||
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
|
||||
optimizer_name = hypernetwork.optimizer_name
|
||||
else:
|
||||
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
|
||||
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
|
||||
optimizer_name = 'AdamW'
|
||||
|
||||
if hypernetwork.optimizer_state_dict:  # This line must be changed if the optimizer type can differ from the one saved with the hypernetwork.
|
||||
try:
|
||||
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
|
||||
except RuntimeError as e:
|
||||
print("Cannot resume from saved optimizer!")
|
||||
print(e)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
batch_size = ds.batch_size
|
||||
gradient_step = ds.gradient_step
|
||||
# one optimizer step consumes batch_size * gradient_step images
|
||||
steps_per_epoch = len(ds) // batch_size // gradient_step
|
||||
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
|
||||
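# Illustrative numbers (hypothetical dataset size): with len(ds)=100, batch_size=2 and
# gradient_step=4 there are 50 batches per epoch, steps_per_epoch = 50 // 4 = 12 optimizer
# steps, and max_steps_per_epoch = 50 - 50 % 4 = 48, so the last 2 batches of each epoch are
# dropped instead of producing a partial gradient accumulation step.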
loss_step = 0
|
||||
_loss_step = 0 #internal
|
||||
# size = len(ds.indexes)
|
||||
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
||||
# losses = torch.zeros((size,))
|
||||
# previous_mean_losses = [0]
|
||||
# previous_mean_loss = 0
|
||||
# print("Mean loss of {} elements".format(size))
|
||||
|
||||
steps_without_grad = 0
|
||||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
|
||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||
try:
|
||||
for i in range((steps-initial_step) * gradient_step):
|
||||
if scheduler.finished:
|
||||
break
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
for j, batch in enumerate(dl):
|
||||
# works as a drop_last=True for gradient accumulation
|
||||
if j == max_steps_per_epoch:
|
||||
break
|
||||
scheduler.apply(optimizer, hypernetwork.step)
|
||||
if scheduler.finished:
|
||||
break
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
with devices.autocast():
|
||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||
if tag_drop_out != 0 or shuffle_tags:
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
else:
|
||||
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
|
||||
loss = shared.sd_model(x, c)[0] / gradient_step
|
||||
del x
|
||||
del c
|
||||
|
||||
_loss_step += loss.item()
|
||||
scaler.scale(loss).backward()
|
||||
# keep accumulating gradients until the configured number of gradient accumulation steps is reached
|
||||
if (j + 1) % gradient_step != 0:
|
||||
continue
|
||||
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
|
||||
# scaler.unscale_(optimizer)
|
||||
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
||||
# torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
|
||||
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
hypernetwork.step += 1
|
||||
pbar.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
loss_step = _loss_step
|
||||
_loss_step = 0
|
||||
|
||||
steps_done = hypernetwork.step + 1
|
||||
|
||||
epoch_num = hypernetwork.step // steps_per_epoch
|
||||
epoch_step = hypernetwork.step % steps_per_epoch
|
||||
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
|
||||
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
||||
hypernetwork.optimizer_name = optimizer_name
|
||||
if shared.opts.save_optimizer_state:
|
||||
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||
|
||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
|
||||
"loss": f"{loss_step:.7f}",
|
||||
"learn_rate": scheduler.learn_rate
|
||||
})
|
||||
|
||||
if images_dir is not None and steps_done % create_image_every == 0:
|
||||
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
hypernetwork.eval_mode()
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
do_not_save_grid=True,
|
||||
do_not_save_samples=True,
|
||||
)
|
||||
|
||||
if preview_from_txt2img:
|
||||
p.prompt = preview_prompt
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
p.steps = preview_steps
|
||||
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||
p.cfg_scale = preview_cfg_scale
|
||||
p.seed = preview_seed
|
||||
p.width = preview_width
|
||||
p.height = preview_height
|
||||
else:
|
||||
p.prompt = batch.cond_text[0]
|
||||
p.steps = 20
|
||||
p.width = training_width
|
||||
p.height = training_height
|
||||
|
||||
preview_text = p.prompt
|
||||
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0] if len(processed.images) > 0 else None
|
||||
|
||||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
hypernetwork.train_mode()
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
shared.state.job_no = hypernetwork.step
|
||||
|
||||
shared.state.textinfo = f"""
|
||||
<p>
|
||||
Loss: {loss_step:.7f}<br/>
|
||||
Step: {steps_done}<br/>
|
||||
Last prompt: {html.escape(batch.cond_text[0])}<br/>
|
||||
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
"""
|
||||
except Exception:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
finally:
|
||||
pbar.leave = False
|
||||
pbar.close()
|
||||
hypernetwork.eval_mode()
|
||||
#report_statistics(loss_dict)
|
||||
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
hypernetwork.optimizer_name = optimizer_name
|
||||
if shared.opts.save_optimizer_state:
|
||||
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
||||
|
||||
del optimizer
|
||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
|
||||
return hypernetwork, filename
|
||||
|
||||
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
|
||||
old_hypernetwork_name = hypernetwork.name
|
||||
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
|
||||
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
|
||||
try:
|
||||
hypernetwork.sd_checkpoint = checkpoint.hash
|
||||
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
||||
hypernetwork.name = hypernetwork_name
|
||||
hypernetwork.save(filename)
|
||||
except:
|
||||
hypernetwork.sd_checkpoint = old_sd_checkpoint
|
||||
hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
|
||||
hypernetwork.name = old_hypernetwork_name
|
||||
raise
|
||||
@ -0,0 +1,41 @@
|
||||
import html
|
||||
import os
|
||||
import re
|
||||
|
||||
import gradio as gr
|
||||
import modules.hypernetworks.hypernetwork
|
||||
from modules import devices, sd_hijack, shared
|
||||
|
||||
not_available = ["hardswish", "multiheadattention"]
|
||||
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
||||
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
|
||||
|
||||
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
|
||||
|
||||
|
||||
def train_hypernetwork(*args):
|
||||
|
||||
initial_hypernetwork = shared.loaded_hypernetwork
|
||||
|
||||
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
|
||||
|
||||
try:
|
||||
sd_hijack.undo_optimizations()
|
||||
|
||||
hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
|
||||
|
||||
res = f"""
|
||||
Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
|
||||
Hypernetwork saved to {html.escape(filename)}
|
||||
"""
|
||||
return res, ""
|
||||
except Exception:
|
||||
raise
|
||||
finally:
|
||||
shared.loaded_hypernetwork = initial_hypernetwork
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
sd_hijack.apply_optimizations()
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
import sys
|
||||
|
||||
# this will break any attempt to import xformers, which will prevent the stable diffusion repo from trying to use it
|
||||
if "--xformers" not in "".join(sys.argv):
|
||||
sys.modules["xformers"] = None
|
||||
@ -0,0 +1,37 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
|
||||
localizations = {}
|
||||
|
||||
|
||||
def list_localizations(dirname):
|
||||
localizations.clear()
|
||||
|
||||
for file in os.listdir(dirname):
|
||||
fn, ext = os.path.splitext(file)
|
||||
if ext.lower() != ".json":
|
||||
continue
|
||||
|
||||
localizations[fn] = os.path.join(dirname, file)
|
||||
|
||||
from modules import scripts
|
||||
for file in scripts.list_scripts("localizations", ".json"):
|
||||
fn, ext = os.path.splitext(file.filename)
|
||||
localizations[fn] = file.path
|
||||
|
||||
|
||||
def localization_js(current_localization_name):
|
||||
fn = localizations.get(current_localization_name, None)
|
||||
data = {}
|
||||
if fn is not None:
|
||||
try:
|
||||
with open(fn, "r", encoding="utf8") as file:
|
||||
data = json.load(file)
|
||||
except Exception:
|
||||
print(f"Error loading localization from {fn}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
return f"var localization = {json.dumps(data)}\n"
|
||||
@ -0,0 +1,26 @@
|
||||
from pyngrok import ngrok, conf, exception
|
||||
|
||||
def connect(token, port, region):
|
||||
account = None
|
||||
if token is None:
|
||||
token = 'None'
|
||||
else:
|
||||
if ':' in token:
|
||||
# token = authtoken:username:password
|
||||
account = token.split(':')[1] + ':' + token.split(':')[-1]
|
||||
token = token.split(':')[0]
|
||||
|
||||
config = conf.PyngrokConfig(
|
||||
auth_token=token, region=region
|
||||
)
|
||||
try:
|
||||
if account is None:
|
||||
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
|
||||
else:
|
||||
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
|
||||
except exception.PyngrokNgrokError:
|
||||
print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
|
||||
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
|
||||
else:
|
||||
print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
|
||||
'You can use this link after the launch is complete.')
|
||||
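# Illustrative call (all values hypothetical): connect("2abc...xyz:user:pass", 7860, "us")
# opens a TLS tunnel to localhost:7860; a token of the form "authtoken:username:password"
# additionally protects the tunnel with HTTP basic auth via the auth argument.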
@ -0,0 +1,192 @@
# this code is adapted from the script contributed by anon from /h/

import io
import pickle
import collections
import sys
import traceback

import torch
import numpy
import _codecs
import zipfile
import re


# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage


def encode(*args):
    out = _codecs.encode(*args)
    return out


class RestrictedUnpickler(pickle.Unpickler):
    extra_handler = None

    def persistent_load(self, saved_id):
        assert saved_id[0] == 'storage'
        return TypedStorage()

    def find_class(self, module, name):
        if self.extra_handler is not None:
            res = self.extra_handler(module, name)
            if res is not None:
                return res

        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
            return getattr(torch._utils, name)
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)
        if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
            return getattr(numpy.core.multiarray, name)
        if module == 'numpy' and name in ['dtype', 'ndarray']:
            return getattr(numpy, name)
        if module == '_codecs' and name == 'encode':
            return encode
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            return set

        # Forbid everything else.
        raise Exception(f"global '{module}/{name}' is forbidden")


# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")

def check_zip_filenames(filename, names):
    for name in names:
        if allowed_zip_names_re.match(name):
            continue

        raise Exception(f"bad file inside {filename}: {name}")


def check_pt(filename, extra_handler):
    try:

        # new pytorch format is a zip file
        with zipfile.ZipFile(filename) as z:
            check_zip_filenames(filename, z.namelist())

            # find filename of data.pkl in zip file: '<directory name>/data.pkl'
            data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
            if len(data_pkl_filenames) == 0:
                raise Exception(f"data.pkl not found in {filename}")
            if len(data_pkl_filenames) > 1:
                raise Exception(f"Multiple data.pkl found in {filename}")
            with z.open(data_pkl_filenames[0]) as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.extra_handler = extra_handler
                unpickler.load()

    except zipfile.BadZipfile:

        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
        with open(filename, "rb") as file:
            unpickler = RestrictedUnpickler(file)
            unpickler.extra_handler = extra_handler
            for i in range(5):
                unpickler.load()


def load(filename, *args, **kwargs):
    return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)


def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    this function is intended to be used by extensions that want to load models with
    some extra classes in them that the usual unpickler would find suspicious.

    Use the extra_handler argument to specify a function that takes module and field name as text,
    and returns that field's value:

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
    definitely unsafe.
    """

    from modules import shared

    try:
        if not shared.cmd_opts.disable_safe_unpickle:
            check_pt(filename, extra_handler)

    except pickle.UnpicklingError:
        print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
        print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
        return None

    except Exception:
        print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
        print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
        return None

    return unsafe_torch_load(filename, *args, **kwargs)


class Extra:
    """
    A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
    (because it's not your code making the torch.load call). The intended use is like this:

    ```
    import torch
    from modules import safe

    def handler(module, name):
        if module == 'torch' and name in ['float64', 'float16']:
            return getattr(torch, name)

        return None

    with safe.Extra(handler):
        x = torch.load('model.pt')
    ```
    """

    def __init__(self, handler):
        self.handler = handler

    def __enter__(self):
        global global_extra_handler

        assert global_extra_handler is None, 'already inside an Extra() block'
        global_extra_handler = self.handler

    def __exit__(self, exc_type, exc_val, exc_tb):
        global global_extra_handler

        global_extra_handler = None


unsafe_torch_load = torch.load
torch.load = load
global_extra_handler = None

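A sketch of what the allow-list guards against (illustrative only; the exact exception text depends on the platform's module name for os.system):

```python
import io
import os
import pickle


class Exploit:
    def __reduce__(self):
        # a malicious pickle typically smuggles in an arbitrary callable such as os.system
        return os.system, ("echo pwned",)


payload = pickle.dumps(Exploit())
RestrictedUnpickler(io.BytesIO(payload)).load()  # raises Exception: global '.../system' is forbidden
```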
@ -1,42 +0,0 @@
import torch
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from PIL import Image

import modules.shared as shared

safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = None
safety_checker = None

def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]

    return pil_images

# check and replace nsfw content
def check_safety(x_image):
    global safety_feature_extractor, safety_checker

    if safety_feature_extractor is None:
        safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)

    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)

    return x_checked_image, has_nsfw_concept


def censor_batch(x):
    x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
    x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
    x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)

    return x
@ -0,0 +1,281 @@
import sys
import traceback
from collections import namedtuple
import inspect
from typing import Optional

from fastapi import FastAPI
from gradio import Blocks


def report_exception(c, job):
    print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
    print(traceback.format_exc(), file=sys.stderr)


class ImageSaveParams:
    def __init__(self, image, p, filename, pnginfo):
        self.image = image
        """the PIL image itself"""

        self.p = p
        """p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""

        self.filename = filename
        """name of file that the image would be saved to"""

        self.pnginfo = pnginfo
        """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""


class CFGDenoiserParams:
    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
        self.x = x
        """Latent image representation in the process of being denoised"""

        self.image_cond = image_cond
        """Conditioning image"""

        self.sigma = sigma
        """Current sigma noise step value"""

        self.sampling_step = sampling_step
        """Current Sampling step number"""

        self.total_sampling_steps = total_sampling_steps
        """Total number of sampling steps planned"""


class UiTrainTabParams:
    def __init__(self, txt2img_preview_params):
        self.txt2img_preview_params = txt2img_preview_params


class ImageGridLoopParams:
    def __init__(self, imgs, cols, rows):
        self.imgs = imgs
        self.cols = cols
        self.rows = rows


ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict(
    callbacks_app_started=[],
    callbacks_model_loaded=[],
    callbacks_ui_tabs=[],
    callbacks_ui_train_tabs=[],
    callbacks_ui_settings=[],
    callbacks_before_image_saved=[],
    callbacks_image_saved=[],
    callbacks_cfg_denoiser=[],
    callbacks_before_component=[],
    callbacks_after_component=[],
    callbacks_image_grid=[],
)


def clear_callbacks():
    for callback_list in callback_map.values():
        callback_list.clear()


def app_started_callback(demo: Optional[Blocks], app: FastAPI):
    for c in callback_map['callbacks_app_started']:
        try:
            c.callback(demo, app)
        except Exception:
            report_exception(c, 'app_started_callback')


def model_loaded_callback(sd_model):
    for c in callback_map['callbacks_model_loaded']:
        try:
            c.callback(sd_model)
        except Exception:
            report_exception(c, 'model_loaded_callback')


def ui_tabs_callback():
    res = []

    for c in callback_map['callbacks_ui_tabs']:
        try:
            res += c.callback() or []
        except Exception:
            report_exception(c, 'ui_tabs_callback')

    return res


def ui_train_tabs_callback(params: UiTrainTabParams):
    for c in callback_map['callbacks_ui_train_tabs']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'callbacks_ui_train_tabs')


def ui_settings_callback():
    for c in callback_map['callbacks_ui_settings']:
        try:
            c.callback()
        except Exception:
            report_exception(c, 'ui_settings_callback')


def before_image_saved_callback(params: ImageSaveParams):
    for c in callback_map['callbacks_before_image_saved']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'before_image_saved_callback')


def image_saved_callback(params: ImageSaveParams):
    for c in callback_map['callbacks_image_saved']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'image_saved_callback')


def cfg_denoiser_callback(params: CFGDenoiserParams):
    for c in callback_map['callbacks_cfg_denoiser']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'cfg_denoiser_callback')


def before_component_callback(component, **kwargs):
    for c in callback_map['callbacks_before_component']:
        try:
            c.callback(component, **kwargs)
        except Exception:
            report_exception(c, 'before_component_callback')


def after_component_callback(component, **kwargs):
    for c in callback_map['callbacks_after_component']:
        try:
            c.callback(component, **kwargs)
        except Exception:
            report_exception(c, 'after_component_callback')


def image_grid_callback(params: ImageGridLoopParams):
    for c in callback_map['callbacks_image_grid']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'image_grid')


def add_callback(callbacks, fun):
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if len(stack) > 0 else 'unknown file'

    callbacks.append(ScriptCallback(filename, fun))


def remove_current_script_callbacks():
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if len(stack) > 0 else 'unknown file'
    if filename == 'unknown file':
        return
    for callback_list in callback_map.values():
        for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
            callback_list.remove(callback_to_remove)


def remove_callbacks_for_function(callback_func):
    for callback_list in callback_map.values():
        for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
            callback_list.remove(callback_to_remove)


def on_app_started(callback):
    """register a function to be called when the webui starts; the gradio `Block` component and
    fastapi `FastAPI` object are passed as the arguments"""
    add_callback(callback_map['callbacks_app_started'], callback)


def on_model_loaded(callback):
    """register a function to be called when the stable diffusion model is created; the model is
    passed as an argument"""
    add_callback(callback_map['callbacks_model_loaded'], callback)


def on_ui_tabs(callback):
    """register a function to be called when the UI is creating new tabs.
    The function must either return a None, which means no new tabs to be added, or a list, where
    each element is a tuple:
        (gradio_component, title, elem_id)

    gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
    title is tab text displayed to user in the UI
    elem_id is HTML id for the tab
    """
    add_callback(callback_map['callbacks_ui_tabs'], callback)


def on_ui_train_tabs(callback):
    """register a function to be called when the UI is creating new tabs for the train tab.
    Create your new tabs with gr.Tab.
    """
    add_callback(callback_map['callbacks_ui_train_tabs'], callback)


def on_ui_settings(callback):
    """register a function to be called before UI settings are populated; add your settings
    by using shared.opts.add_option(shared.OptionInfo(...)) """
    add_callback(callback_map['callbacks_ui_settings'], callback)


def on_before_image_saved(callback):
    """register a function to be called before an image is saved to a file.
    The callback is called with one argument:
        - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
    """
    add_callback(callback_map['callbacks_before_image_saved'], callback)


def on_image_saved(callback):
    """register a function to be called after an image is saved to a file.
    The callback is called with one argument:
        - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
    """
    add_callback(callback_map['callbacks_image_saved'], callback)


def on_cfg_denoiser(callback):
    """register a function to be called in the k-diffusion cfg_denoiser method after building the inner model inputs.
    The callback is called with one argument:
        - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
    """
    add_callback(callback_map['callbacks_cfg_denoiser'], callback)


def on_before_component(callback):
    """register a function to be called before a component is created.
    The callback is called with arguments:
        - component - gradio component that is about to be created.
        - **kwargs - args to gradio.components.IOComponent.__init__ function

    Use elem_id/label fields of kwargs to figure out which component it is.
    This can be useful to inject your own components somewhere in the middle of vanilla UI.
    """
    add_callback(callback_map['callbacks_before_component'], callback)


def on_after_component(callback):
    """register a function to be called after a component is created. See on_before_component for more."""
    add_callback(callback_map['callbacks_after_component'], callback)


def on_image_grid(callback):
    """register a function to be called before making an image grid.
    The callback is called with one argument:
        - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
    """
    add_callback(callback_map['callbacks_image_grid'], callback)
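A minimal sketch of an extension script registering these hooks (the file path, option key, labels and callback bodies are invented for illustration):

```python
# hypothetical extensions/my-extension/scripts/my_hooks.py
from modules import script_callbacks, shared


def add_my_settings():
    shared.opts.add_option("my_extension_enabled",
                           shared.OptionInfo(True, "Enable my extension", section=("my_extension", "My Extension")))


def log_saved_image(params: script_callbacks.ImageSaveParams):
    print(f"saved {params.filename}, infotext: {params.pnginfo.get('parameters', '')!r}")


script_callbacks.on_ui_settings(add_my_settings)
script_callbacks.on_image_saved(log_saved_image)
```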
@ -0,0 +1,34 @@
import os
import sys
import traceback
from types import ModuleType


def load_module(path):
    with open(path, "r", encoding="utf8") as file:
        text = file.read()

    compiled = compile(text, path, 'exec')
    module = ModuleType(os.path.basename(path))
    exec(compiled, module.__dict__)

    return module


def preload_extensions(extensions_dir, parser):
    if not os.path.isdir(extensions_dir):
        return

    for dirname in sorted(os.listdir(extensions_dir)):
        preload_script = os.path.join(extensions_dir, dirname, "preload.py")
        if not os.path.isfile(preload_script):
            continue

        try:
            module = load_module(preload_script)
            if hasattr(module, 'preload'):
                module.preload(parser)

        except Exception:
            print(f"Error running preload() for {preload_script}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
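preload() gives an extension a chance to register command line options before the main argument parser runs; a sketch of such a file (the flag name is invented):

```python
# hypothetical extensions/<name>/preload.py
def preload(parser):
    # parser is the argparse parser used for webui command line flags
    parser.add_argument("--my-extension-models-path", type=str, default=None,
                        help="path to models used by this hypothetical extension")
```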
@ -0,0 +1,10 @@
from torch.utils.checkpoint import checkpoint

def BasicTransformerBlock_forward(self, x, context=None):
    return checkpoint(self._forward, x, context)

def AttentionBlock_forward(self, x):
    return checkpoint(self._forward, x)

def ResBlock_forward(self, x, emb):
    return checkpoint(self._forward, x, emb)
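These wrappers route the blocks' forward passes through torch gradient checkpointing, trading recomputation for memory during training. A sketch of how they might be patched in (module paths assumed from the CompVis ldm package layout; not shown in this diff):

```python
import ldm.modules.attention
import ldm.modules.diffusionmodules.openaimodel

from modules import sd_hijack_checkpoint

# replace the original forwards with the checkpointed versions while training
ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
```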
@ -0,0 +1,303 @@
import math

import torch

from modules import prompt_parser, devices
from modules.shared import opts

def get_target_prompt_token_count(token_count):
    return math.ceil(max(token_count, 1) / 75) * 75


class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
    def __init__(self, wrapped, hijack):
        super().__init__()
        self.wrapped = wrapped
        self.hijack = hijack

    def tokenize(self, texts):
        raise NotImplementedError

    def encode_with_transformers(self, tokens):
        raise NotImplementedError

    def encode_embedding_init_text(self, init_text, nvpt):
        raise NotImplementedError

    def tokenize_line(self, line, used_custom_terms, hijack_comments):
        if opts.enable_emphasis:
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        fixes = []
        remade_tokens = []
        multipliers = []
        last_comma = -1

        for tokens, (text, weight) in zip(tokenized, parsed):
            i = 0
            while i < len(tokens):
                token = tokens[i]

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)

                if token == self.comma_token:
                    last_comma = len(remade_tokens)
                elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
                    last_comma += 1
                    reloc_tokens = remade_tokens[last_comma:]
                    reloc_mults = multipliers[last_comma:]

                    remade_tokens = remade_tokens[:last_comma]
                    length = len(remade_tokens)

                    rem = int(math.ceil(length / 75)) * 75 - length
                    remade_tokens += [self.id_end] * rem + reloc_tokens
                    multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults

                if embedding is None:
                    remade_tokens.append(token)
                    multipliers.append(weight)
                    i += 1
                else:
                    emb_len = int(embedding.vec.shape[0])
                    iteration = len(remade_tokens) // 75
                    if (len(remade_tokens) + emb_len) // 75 != iteration:
                        rem = (75 * (iteration + 1) - len(remade_tokens))
                        remade_tokens += [self.id_end] * rem
                        multipliers += [1.0] * rem
                        iteration += 1
                    fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
                    remade_tokens += [0] * emb_len
                    multipliers += [weight] * emb_len
                    used_custom_terms.append((embedding.name, embedding.checksum()))
                    i += embedding_length_in_tokens

        token_count = len(remade_tokens)
        prompt_target_length = get_target_prompt_token_count(token_count)
        tokens_to_add = prompt_target_length - len(remade_tokens)

        remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
        multipliers = multipliers + [1.0] * tokens_to_add

        return remade_tokens, fixes, multipliers, token_count

    def process_text(self, texts):
        used_custom_terms = []
        remade_batch_tokens = []
        hijack_comments = []
        hijack_fixes = []
        token_count = 0

        cache = {}
        batch_multipliers = []
        for line in texts:
            if line in cache:
                remade_tokens, fixes, multipliers = cache[line]
            else:
                remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
                token_count = max(current_token_count, token_count)

                cache[line] = (remade_tokens, fixes, multipliers)

            remade_batch_tokens.append(remade_tokens)
            hijack_fixes.append(fixes)
            batch_multipliers.append(multipliers)

        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count

    def process_text_old(self, texts):
        id_start = self.id_start
        id_end = self.id_end
        maxlen = self.wrapped.max_length  # you get to stay at 77
        used_custom_terms = []
        remade_batch_tokens = []
        hijack_comments = []
        hijack_fixes = []
        token_count = 0

        cache = {}
        batch_tokens = self.tokenize(texts)
        batch_multipliers = []
        for tokens in batch_tokens:
            tuple_tokens = tuple(tokens)

            if tuple_tokens in cache:
                remade_tokens, fixes, multipliers = cache[tuple_tokens]
            else:
                fixes = []
                remade_tokens = []
                multipliers = []
                mult = 1.0

                i = 0
                while i < len(tokens):
                    token = tokens[i]

                    embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)

                    mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
                    if mult_change is not None:
                        mult *= mult_change
                        i += 1
                    elif embedding is None:
                        remade_tokens.append(token)
                        multipliers.append(mult)
                        i += 1
                    else:
                        emb_len = int(embedding.vec.shape[0])
                        fixes.append((len(remade_tokens), embedding))
                        remade_tokens += [0] * emb_len
                        multipliers += [mult] * emb_len
                        used_custom_terms.append((embedding.name, embedding.checksum()))
                        i += embedding_length_in_tokens

                if len(remade_tokens) > maxlen - 2:
                    vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
                    ovf = remade_tokens[maxlen - 2:]
                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
                    overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
                    hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

                token_count = len(remade_tokens)
                remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
                remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)

            multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

            remade_batch_tokens.append(remade_tokens)
            hijack_fixes.append(fixes)
            batch_multipliers.append(multipliers)
        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count

    def forward(self, text):
        use_old = opts.use_old_emphasis_implementation
        if use_old:
            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
        else:
            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)

        self.hijack.comments += hijack_comments

        if len(used_custom_terms) > 0:
            self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))

        if use_old:
            self.hijack.fixes = hijack_fixes
            return self.process_tokens(remade_batch_tokens, batch_multipliers)

        z = None
        i = 0
        while max(map(len, remade_batch_tokens)) != 0:
            rem_tokens = [x[75:] for x in remade_batch_tokens]
            rem_multipliers = [x[75:] for x in batch_multipliers]

            self.hijack.fixes = []
            for unfiltered in hijack_fixes:
                fixes = []
                for fix in unfiltered:
                    if fix[0] == i:
                        fixes.append(fix[1])
                self.hijack.fixes.append(fixes)

            tokens = []
            multipliers = []
            for j in range(len(remade_batch_tokens)):
                if len(remade_batch_tokens[j]) > 0:
                    tokens.append(remade_batch_tokens[j][:75])
                    multipliers.append(batch_multipliers[j][:75])
                else:
                    tokens.append([self.id_end] * 75)
                    multipliers.append([1.0] * 75)

            z1 = self.process_tokens(tokens, multipliers)
            z = z1 if z is None else torch.cat((z, z1), axis=-2)

            remade_batch_tokens = rem_tokens
            batch_multipliers = rem_multipliers
            i += 1

        return z

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        if not opts.use_old_emphasis_implementation:
            remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
            batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]

        tokens = torch.asarray(remade_batch_tokens).to(devices.device)

        if self.id_end != self.id_pad:
            for batch_pos in range(len(remade_batch_tokens)):
                index = remade_batch_tokens[batch_pos].index(self.id_end)
                tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad

        z = self.encode_with_transformers(tokens)

        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
        batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
        original_mean = z.mean()
        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean

        return z


class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)
        self.tokenizer = wrapped.tokenizer

        vocab = self.tokenizer.get_vocab()

        self.comma_token = vocab.get(',</w>', None)

        self.token_mults = {}
        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult

        self.id_start = self.wrapped.tokenizer.bos_token_id
        self.id_end = self.wrapped.tokenizer.eos_token_id
        self.id_pad = self.id_end

    def tokenize(self, texts):
        tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

        return tokenized

    def encode_with_transformers(self, tokens):
        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)

        if opts.CLIP_stop_at_last_layers > 1:
            z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
            z = self.wrapped.transformer.text_model.final_layer_norm(z)
        else:
            z = outputs.last_hidden_state

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        embedding_layer = self.wrapped.transformer.text_model.embeddings
        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)

        return embedded
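The chunking above always pads a prompt up to a multiple of 75 tokens, so each chunk fills one 77-token CLIP context once BOS/EOS are added; a quick check of get_target_prompt_token_count (a sketch, duplicating the helper for illustration):

```python
import math

def get_target_prompt_token_count(token_count):
    return math.ceil(max(token_count, 1) / 75) * 75

assert get_target_prompt_token_count(1) == 75     # short prompts use a single chunk
assert get_target_prompt_token_count(75) == 75
assert get_target_prompt_token_count(76) == 150   # one token over spills into a second chunk
```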
@ -0,0 +1,111 @@
import os
import torch

from einops import repeat
from omegaconf import ListConfig

import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms

from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like


@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
    b, *_, device = *x.shape, x.device

    def get_model_output(x, t):
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)

            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [
                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
                            for i in range(len(c[k]))
                        ]
                    else:
                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
            else:
                c_in = torch.cat([unconditional_conditioning, c])

            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        return e_t

    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

    def get_x_prev_and_pred_x0(e_t, index):
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        if dynamic_threshold is not None:
            pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    e_t = get_model_output(x, t)
    if len(old_eps) == 0:
        # Pseudo Improved Euler (2nd order)
        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
        e_t_next = get_model_output(x_prev, t_next)
        e_t_prime = (e_t + e_t_next) / 2
    elif len(old_eps) == 1:
        # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (3 * e_t - old_eps[-1]) / 2
    elif len(old_eps) == 2:
        # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    elif len(old_eps) >= 3:
        # 4th order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

    x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

    return x_prev, pred_x0, e_t


def should_hijack_inpainting(checkpoint_info):
    from modules import sd_models

    ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
    cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()

    return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename


def do_inpainting_hijack():
    # p_sample_plms is needed because PLMS can't work with dicts as conditionings

    ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
@ -0,0 +1,37 @@
import open_clip.tokenizer
import torch

from modules import sd_hijack_clip, devices
from modules.shared import opts

tokenizer = open_clip.tokenizer._tokenizer


class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)

        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
        self.id_start = tokenizer.encoder["<start_of_text>"]
        self.id_end = tokenizer.encoder["<end_of_text>"]
        self.id_pad = 0

    def tokenize(self, texts):
        assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'

        tokenized = [tokenizer.encode(text) for text in texts]

        return tokenized

    def encode_with_transformers(self, tokens):
        # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
        z = self.wrapped.encode_with_transformer(tokens)

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        ids = tokenizer.encode(init_text)
        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
        embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)

        return embedded
@ -0,0 +1,314 @@
import math
import sys
import traceback
import importlib

import torch
from torch import einsum

from ldm.util import default
from einops import rearrange

from modules import shared
from modules.hypernetworks import hypernetwork


if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    try:
        import xformers.ops
        shared.xformers_available = True
    except Exception:
        print("Cannot import xformers", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)


# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)
    del context, context_k, context_v, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
    for i in range(0, q.shape[0], 2):
        end = i + 2
        s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
        s1 *= self.scale

        s2 = s1.softmax(dim=-1)
        del s1

        r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
        del s2
    del q, k, v

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


# taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    k_in *= self.scale

    del context, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
    modifier = 3 if q.element_size() == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
        #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

    if steps > 64:
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                           f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

        s2 = s1.softmax(dim=-1, dtype=q.dtype)
        del s1

        r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
        del s2

    del q, k, v

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


def check_for_psutil():
    try:
        spec = importlib.util.find_spec('psutil')
        return spec is not None
    except ModuleNotFoundError:
        return False

invokeAI_mps_available = check_for_psutil()

# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
if invokeAI_mps_available:
    import psutil
    mem_total_gb = psutil.virtual_memory().total // (1 << 30)

def einsum_op_compvis(q, k, v):
    s = einsum('b i d, b j d -> b i j', q, k)
    s = s.softmax(dim=-1, dtype=s.dtype)
    return einsum('b i j, b j d -> b i d', s, v)

def einsum_op_slice_0(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[0], slice_size):
        end = i + slice_size
        r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
    return r

def einsum_op_slice_1(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
    return r

def einsum_op_mps_v1(q, k, v):
    if q.shape[0] * q.shape[1] <= 2**16:  # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)
    else:
        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
        if slice_size % 4096 == 0:
            slice_size -= 1
        return einsum_op_slice_1(q, k, v, slice_size)

def einsum_op_mps_v2(q, k, v):
    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
        return einsum_op_compvis(q, k, v)
    else:
        return einsum_op_slice_0(q, k, v, 1)

def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
    size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
    if size_mb <= max_tensor_mb:
        return einsum_op_compvis(q, k, v)
    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
    if div <= q.shape[0]:
        return einsum_op_slice_0(q, k, v, q.shape[0] // div)
    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))

def einsum_op_cuda(q, k, v):
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    # Divide by a factor of safety as there's copying and fragmentation
    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))

def einsum_op(q, k, v):
    if q.device.type == 'cuda':
        return einsum_op_cuda(q, k, v)

    if q.device.type == 'mps':
        if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
            return einsum_op_mps_v1(q, k, v)
        return einsum_op_mps_v2(q, k, v)

    # Smaller slices are faster due to L2/L3/SLC caches.
    # Tested on i7 with 8MB L3 cache.
    return einsum_op_tensor_mem(q, k, v, 32)

def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
    k = self.to_k(context_k) * self.scale
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
    r = einsum_op(q, k, v)
    return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))

# -- End of code from https://github.com/invoke-ai/InvokeAI --

def xformers_attention_forward(self, x, context=None, mask=None):
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
    return self.to_out(out)

def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h*w)
    del q1

    q = q2.permute(0, 2, 1)    # b,hw,c
    del q2

    k = k1.reshape(b, c, h*w)  # b,c,hw
    del k1

    h_ = torch.zeros_like(k, device=q.device)

    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * (int(c)**(-0.5))
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        v1 = v.reshape(b, c, h*w)
        w4 = w3.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)  # b,c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del v1, w4

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3

def xformers_attnblock_forward(self, x):
    try:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v)
        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
        out = self.proj_out(out)
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)
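The split-attention paths above pick the number of slices as the smallest power of two that makes each slice of the attention matrix fit in free memory; a worked example with invented numbers:

```python
import math

mem_required = 10 * 1024**3   # hypothetical: the full attention matrix would need ~10 GB
mem_free_total = 3 * 1024**3  # hypothetical: ~3 GB free (CUDA free + unused torch reserve)

steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))
print(steps)  # 4, i.e. the sequence dimension is processed in 4 slices
```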
@ -0,0 +1,30 @@
import torch


class TorchHijackForUnet:
    """
    This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
    """

    def __getattr__(self, item):
        if item == 'cat':
            return self.cat

        if hasattr(torch, item):
            return getattr(torch, item)

        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))

    def cat(self, tensors, *args, **kwargs):
        if len(tensors) == 2:
            a, b = tensors
            if a.shape[-2:] != b.shape[-2:]:
                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")

            tensors = (a, b)

        return torch.cat(tensors, *args, **kwargs)


th = TorchHijackForUnet()
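A sketch of what the hijacked cat() buys, using the TorchHijackForUnet class defined above: two skip-connection tensors whose spatial sizes disagree (as can happen with image sizes that are multiples of 8 but not 64) still concatenate, because the first one is resized to match the second:

```python
import torch

th = TorchHijackForUnet()

a = torch.zeros(1, 4, 37, 37)
b = torch.zeros(1, 4, 38, 38)
print(th.cat((a, b), dim=1).shape)  # torch.Size([1, 8, 38, 38])
```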
@ -0,0 +1,34 @@
import open_clip.tokenizer
import torch

from modules import sd_hijack_clip, devices
from modules.shared import opts


class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)

        self.id_start = wrapped.config.bos_token_id
        self.id_end = wrapped.config.eos_token_id
        self.id_pad = wrapped.config.pad_token_id

        self.comma_token = self.tokenizer.get_vocab().get(',', None)  # alt diffusion doesn't have </w> bits for comma

    def encode_with_transformers(self, tokens):
        # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
        # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
        # layer to work with - you have to use the last

        attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
        features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
        z = features['projection_state']

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        embedding_layer = self.wrapped.roberta.embeddings
        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)

        return embedded
Some files were not shown because too many files have changed in this diff.