Merge branch 'AUTOMATIC1111:master' into kr-localization
commit
660ae690bd
@ -1,32 +0,0 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: bug-report
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**To Reproduce**
|
||||
Steps to reproduce the behavior:
|
||||
1. Go to '...'
|
||||
2. Click on '....'
|
||||
3. Scroll down to '....'
|
||||
4. See error
|
||||
|
||||
**Expected behavior**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Screenshots**
|
||||
If applicable, add screenshots to help explain your problem.
|
||||
|
||||
**Desktop (please complete the following information):**
|
||||
- OS: [e.g. Windows, Linux]
|
||||
- Browser [e.g. chrome, safari]
|
||||
- Commit revision [looks like this: e68484500f76a33ba477d5a99340ab30451e557b; can be seen when launching webui.bat, or obtained manually by running `git rev-parse HEAD`]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
||||
@ -0,0 +1,83 @@
|
||||
name: Bug Report
|
||||
description: You think somethings is broken in the UI
|
||||
title: "[Bug]: "
|
||||
labels: ["bug-report"]
|
||||
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Is there an existing issue for this?
|
||||
description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
|
||||
options:
|
||||
- label: I have searched the existing issues and checked the recent builds/commits
|
||||
required: true
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
*Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible**
|
||||
- type: textarea
|
||||
id: what-did
|
||||
attributes:
|
||||
label: What happened?
|
||||
description: Tell us what happened in a very clear and simple way
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: steps
|
||||
attributes:
|
||||
label: Steps to reproduce the problem
|
||||
description: Please provide us with precise step by step information on how to reproduce the bug
|
||||
value: |
|
||||
1. Go to ....
|
||||
2. Press ....
|
||||
3. ...
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: what-should
|
||||
attributes:
|
||||
label: What should have happened?
|
||||
description: tell what you think the normal behavior should be
|
||||
validations:
|
||||
required: true
|
||||
- type: input
|
||||
id: commit
|
||||
attributes:
|
||||
label: Commit where the problem happens
|
||||
description: Which commit are you running ? (copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
id: platforms
|
||||
attributes:
|
||||
label: What platforms do you use to access UI ?
|
||||
multiple: true
|
||||
options:
|
||||
- Windows
|
||||
- Linux
|
||||
- MacOS
|
||||
- iOS
|
||||
- Android
|
||||
- Other/Cloud
|
||||
- type: dropdown
|
||||
id: browsers
|
||||
attributes:
|
||||
label: What browsers do you use to access the UI ?
|
||||
multiple: true
|
||||
options:
|
||||
- Mozilla Firefox
|
||||
- Google Chrome
|
||||
- Brave
|
||||
- Apple Safari
|
||||
- Microsoft Edge
|
||||
- type: textarea
|
||||
id: cmdargs
|
||||
attributes:
|
||||
label: Command Line Arguments
|
||||
description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below
|
||||
render: Shell
|
||||
- type: textarea
|
||||
id: misc
|
||||
attributes:
|
||||
label: Additional information, context and logs
|
||||
description: Please provide us with any relevant additional info, context or log output.
|
||||
@ -0,0 +1,5 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: WebUI Community Support
|
||||
url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
|
||||
about: Please ask and answer questions here.
|
||||
@ -1,20 +0,0 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: 'suggestion'
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
**Describe the solution you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Describe alternatives you've considered**
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
**Additional context**
|
||||
Add any other context or screenshots about the feature request here.
|
||||
@ -0,0 +1,40 @@
|
||||
name: Feature request
|
||||
description: Suggest an idea for this project
|
||||
title: "[Feature Request]: "
|
||||
labels: ["suggestion"]
|
||||
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Is there an existing issue for this?
|
||||
description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit.
|
||||
options:
|
||||
- label: I have searched the existing issues and checked the recent builds/commits
|
||||
required: true
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
*Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible*
|
||||
- type: textarea
|
||||
id: feature
|
||||
attributes:
|
||||
label: What would your feature do ?
|
||||
description: Tell us about your feature in a very clear and simple way, and what problem it would solve
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: workflow
|
||||
attributes:
|
||||
label: Proposed workflow
|
||||
description: Please provide us with step by step information on how you'd like the feature to be accessed and used
|
||||
value: |
|
||||
1. Go to ....
|
||||
2. Press ....
|
||||
3. ...
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: misc
|
||||
attributes:
|
||||
label: Additional information
|
||||
description: Add any other context or screenshots about the feature request here.
|
||||
|
@ -0,0 +1,124 @@
|
||||
from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.sd_samplers import all_samplers
|
||||
from modules.extras import run_pnginfo
|
||||
import modules.shared as shared
|
||||
import uvicorn
|
||||
from fastapi import Body, APIRouter, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field, Json
|
||||
import json
|
||||
import io
|
||||
import base64
|
||||
from PIL import Image
|
||||
|
||||
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
|
||||
|
||||
class TextToImageResponse(BaseModel):
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: Json
|
||||
info: Json
|
||||
|
||||
class ImageToImageResponse(BaseModel):
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: Json
|
||||
info: Json
|
||||
|
||||
|
||||
class Api:
|
||||
def __init__(self, app, queue_lock):
|
||||
self.router = APIRouter()
|
||||
self.app = app
|
||||
self.queue_lock = queue_lock
|
||||
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
|
||||
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
|
||||
|
||||
def __base64_to_image(self, base64_string):
|
||||
# if has a comma, deal with prefix
|
||||
if "," in base64_string:
|
||||
base64_string = base64_string.split(",")[1]
|
||||
imgdata = base64.b64decode(base64_string)
|
||||
# convert base64 to PIL image
|
||||
return Image.open(io.BytesIO(imgdata))
|
||||
|
||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
|
||||
|
||||
if sampler_index is None:
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
|
||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||
"sd_model": shared.sd_model,
|
||||
"sampler_index": sampler_index[0],
|
||||
"do_not_save_samples": True,
|
||||
"do_not_save_grid": True
|
||||
}
|
||||
)
|
||||
p = StableDiffusionProcessingTxt2Img(**vars(populate))
|
||||
# Override object param
|
||||
with self.queue_lock:
|
||||
processed = process_images(p)
|
||||
|
||||
b64images = []
|
||||
for i in processed.images:
|
||||
buffer = io.BytesIO()
|
||||
i.save(buffer, format="png")
|
||||
b64images.append(base64.b64encode(buffer.getvalue()))
|
||||
|
||||
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
|
||||
|
||||
|
||||
|
||||
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
|
||||
sampler_index = sampler_to_index(img2imgreq.sampler_index)
|
||||
|
||||
if sampler_index is None:
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
|
||||
|
||||
init_images = img2imgreq.init_images
|
||||
if init_images is None:
|
||||
raise HTTPException(status_code=404, detail="Init image not found")
|
||||
|
||||
mask = img2imgreq.mask
|
||||
if mask:
|
||||
mask = self.__base64_to_image(mask)
|
||||
|
||||
|
||||
populate = img2imgreq.copy(update={ # Override __init__ params
|
||||
"sd_model": shared.sd_model,
|
||||
"sampler_index": sampler_index[0],
|
||||
"do_not_save_samples": True,
|
||||
"do_not_save_grid": True,
|
||||
"mask": mask
|
||||
}
|
||||
)
|
||||
p = StableDiffusionProcessingImg2Img(**vars(populate))
|
||||
|
||||
imgs = []
|
||||
for img in init_images:
|
||||
img = self.__base64_to_image(img)
|
||||
imgs = [img] * p.batch_size
|
||||
|
||||
p.init_images = imgs
|
||||
# Override object param
|
||||
with self.queue_lock:
|
||||
processed = process_images(p)
|
||||
|
||||
b64images = []
|
||||
for i in processed.images:
|
||||
buffer = io.BytesIO()
|
||||
i.save(buffer, format="png")
|
||||
b64images.append(base64.b64encode(buffer.getvalue()))
|
||||
|
||||
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info))
|
||||
|
||||
def extrasapi(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def pnginfoapi(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def launch(self, server_name, port):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(self.app, host=server_name, port=port)
|
||||
@ -0,0 +1,106 @@
|
||||
from array import array
|
||||
from inflection import underscore
|
||||
from typing import Any, Dict, Optional
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||
import inspect
|
||||
|
||||
|
||||
API_NOT_ALLOWED = [
|
||||
"self",
|
||||
"kwargs",
|
||||
"sd_model",
|
||||
"outpath_samples",
|
||||
"outpath_grids",
|
||||
"sampler_index",
|
||||
"do_not_save_samples",
|
||||
"do_not_save_grid",
|
||||
"extra_generation_params",
|
||||
"overlay_images",
|
||||
"do_not_reload_embeddings",
|
||||
"seed_enable_extras",
|
||||
"prompt_for_display",
|
||||
"sampler_noise_scheduler_override",
|
||||
"ddim_discretize"
|
||||
]
|
||||
|
||||
class ModelDef(BaseModel):
|
||||
"""Assistance Class for Pydantic Dynamic Model Generation"""
|
||||
|
||||
field: str
|
||||
field_alias: str
|
||||
field_type: Any
|
||||
field_value: Any
|
||||
|
||||
|
||||
class PydanticModelGenerator:
|
||||
"""
|
||||
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
|
||||
source_data is a snapshot of the default values produced by the class
|
||||
params are the names of the actual keys required by __init__
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = None,
|
||||
class_instance = None,
|
||||
additional_fields = None,
|
||||
):
|
||||
def field_type_generator(k, v):
|
||||
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
||||
# print(k, v.annotation, v.default)
|
||||
field_type = v.annotation
|
||||
|
||||
return Optional[field_type]
|
||||
|
||||
def merge_class_params(class_):
|
||||
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
|
||||
parameters = {}
|
||||
for classes in all_classes:
|
||||
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||
return parameters
|
||||
|
||||
|
||||
self._model_name = model_name
|
||||
self._class_data = merge_class_params(class_instance)
|
||||
self._model_def = [
|
||||
ModelDef(
|
||||
field=underscore(k),
|
||||
field_alias=k,
|
||||
field_type=field_type_generator(k, v),
|
||||
field_value=v.default
|
||||
)
|
||||
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||
]
|
||||
|
||||
for fields in additional_fields:
|
||||
self._model_def.append(ModelDef(
|
||||
field=underscore(fields["key"]),
|
||||
field_alias=fields["key"],
|
||||
field_type=fields["type"],
|
||||
field_value=fields["default"]))
|
||||
|
||||
def generate_model(self):
|
||||
"""
|
||||
Creates a pydantic BaseModel
|
||||
from the json and overrides provided at initialization
|
||||
"""
|
||||
fields = {
|
||||
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
|
||||
}
|
||||
DynamicModel = create_model(self._model_name, **fields)
|
||||
DynamicModel.__config__.allow_population_by_field_name = True
|
||||
DynamicModel.__config__.allow_mutation = True
|
||||
return DynamicModel
|
||||
|
||||
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingTxt2Img",
|
||||
StableDiffusionProcessingTxt2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}]
|
||||
).generate_model()
|
||||
|
||||
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingImg2Img",
|
||||
StableDiffusionProcessingImg2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
|
||||
).generate_model()
|
||||
@ -1,76 +0,0 @@
|
||||
import os.path
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import PIL.Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.upscaler
|
||||
from modules import devices, modelloader
|
||||
from modules.bsrgan_model_arch import RRDBNet
|
||||
|
||||
|
||||
class UpscalerBSRGAN(modules.upscaler.Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "BSRGAN"
|
||||
self.model_name = "BSRGAN 4x"
|
||||
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
|
||||
self.user_path = dirname
|
||||
super().__init__()
|
||||
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
||||
scalers = []
|
||||
if len(model_paths) == 0:
|
||||
scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
for file in model_paths:
|
||||
if "http" in file:
|
||||
name = self.model_name
|
||||
else:
|
||||
name = modelloader.friendly_name(file)
|
||||
try:
|
||||
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
except Exception:
|
||||
print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
self.scalers = scalers
|
||||
|
||||
def do_upscale(self, img: PIL.Image, selected_file):
|
||||
torch.cuda.empty_cache()
|
||||
model = self.load_model(selected_file)
|
||||
if model is None:
|
||||
return img
|
||||
model.to(devices.device_bsrgan)
|
||||
torch.cuda.empty_cache()
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(devices.device_bsrgan)
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output = 255. * np.moveaxis(output, 0, 2)
|
||||
output = output.astype(np.uint8)
|
||||
output = output[:, :, ::-1]
|
||||
torch.cuda.empty_cache()
|
||||
return PIL.Image.fromarray(output, 'RGB')
|
||||
|
||||
def load_model(self, path: str):
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
||||
progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
|
||||
return None
|
||||
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
|
||||
model.load_state_dict(torch.load(filename), strict=True)
|
||||
model.eval()
|
||||
for k, v in model.named_parameters():
|
||||
v.requires_grad = False
|
||||
return model
|
||||
|
||||
@ -1,102 +0,0 @@
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
|
||||
|
||||
def initialize_weights(net_l, scale=1):
|
||||
if not isinstance(net_l, list):
|
||||
net_l = [net_l]
|
||||
for net in net_l:
|
||||
for m in net.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale # for residual block
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias.data, 0.0)
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
# initialization
|
||||
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
'''Residual in Residual Dense Block'''
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
self.sf = sf
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
if self.sf==4:
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
|
||||
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
if self.sf==4:
|
||||
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
return out
|
||||
@ -1,80 +1,463 @@
|
||||
# this file is taken from https://github.com/xinntao/ESRGAN
|
||||
# this file is adapted from https://github.com/victorca25/iNNfer
|
||||
|
||||
import math
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
####################
|
||||
# RRDBNet Generator
|
||||
####################
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
|
||||
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
|
||||
finalact=None, gaussian_noise=False, plus=False):
|
||||
super(RRDBNet, self).__init__()
|
||||
n_upscale = int(math.log(upscale, 2))
|
||||
if upscale == 3:
|
||||
n_upscale = 1
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
self.resrgan_scale = 0
|
||||
if in_nc % 16 == 0:
|
||||
self.resrgan_scale = 1
|
||||
elif in_nc != 4 and in_nc % 4 == 0:
|
||||
self.resrgan_scale = 2
|
||||
|
||||
# initialization
|
||||
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
||||
rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
|
||||
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
|
||||
LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
if upsample_mode == 'upconv':
|
||||
upsample_block = upconv_block
|
||||
elif upsample_mode == 'pixelshuffle':
|
||||
upsample_block = pixelshuffle_block
|
||||
else:
|
||||
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
|
||||
if upscale == 3:
|
||||
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
|
||||
else:
|
||||
upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
|
||||
HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
|
||||
HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
||||
|
||||
outact = act(finalact) if finalact else None
|
||||
|
||||
self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
|
||||
*upsampler, HR_conv0, HR_conv1, outact)
|
||||
|
||||
def forward(self, x, outm=None):
|
||||
if self.resrgan_scale == 1:
|
||||
feat = pixel_unshuffle(x, scale=4)
|
||||
elif self.resrgan_scale == 2:
|
||||
feat = pixel_unshuffle(x, scale=2)
|
||||
else:
|
||||
feat = x
|
||||
|
||||
return self.model(feat)
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
'''Residual in Residual Dense Block'''
|
||||
"""
|
||||
Residual in Residual Dense Block
|
||||
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
|
||||
"""
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False, gaussian_noise=False, plus=False):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
# This is for backwards compatibility with existing models
|
||||
if nr == 3:
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
else:
|
||||
RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
|
||||
self.RDBs = nn.Sequential(*RDB_list)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
if hasattr(self, 'RDB1'):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
else:
|
||||
out = self.RDBs(x)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
"""
|
||||
Residual Dense Block
|
||||
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
|
||||
Modified options that can be used:
|
||||
- "Partial Convolution based Padding" arXiv:1811.11718
|
||||
- "Spectral normalization" arXiv:1802.05957
|
||||
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
||||
{Rakotonirina} and A. {Rasoanaivo}
|
||||
"""
|
||||
|
||||
def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False, gaussian_noise=False, plus=False):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
self.noise = GaussianNoise() if gaussian_noise else None
|
||||
self.conv1x1 = conv1x1(nf, gc) if plus else None
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
if mode == 'CNA':
|
||||
last_act = None
|
||||
else:
|
||||
last_act = act_type
|
||||
self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
x1 = self.conv1(x)
|
||||
x2 = self.conv2(torch.cat((x, x1), 1))
|
||||
if self.conv1x1:
|
||||
x2 = x2 + self.conv1x1(x)
|
||||
x3 = self.conv3(torch.cat((x, x1, x2), 1))
|
||||
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
|
||||
if self.conv1x1:
|
||||
x4 = x4 + x2
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
if self.noise:
|
||||
return self.noise(x5.mul(0.2) + x)
|
||||
else:
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
####################
|
||||
# ESRGANplus
|
||||
####################
|
||||
|
||||
class GaussianNoise(nn.Module):
|
||||
def __init__(self, sigma=0.1, is_relative_detach=False):
|
||||
super().__init__()
|
||||
self.sigma = sigma
|
||||
self.is_relative_detach = is_relative_detach
|
||||
self.noise = torch.tensor(0, dtype=torch.float)
|
||||
|
||||
def forward(self, x):
|
||||
if self.training and self.sigma != 0:
|
||||
self.noise = self.noise.to(x.device)
|
||||
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
||||
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
||||
x = x + sampled_noise
|
||||
return x
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
####################
|
||||
# SRVGGNetCompact
|
||||
####################
|
||||
|
||||
class SRVGGNetCompact(nn.Module):
|
||||
"""A compact VGG-style network structure for super-resolution.
|
||||
This class is copied from https://github.com/xinntao/Real-ESRGAN
|
||||
"""
|
||||
|
||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
|
||||
super(SRVGGNetCompact, self).__init__()
|
||||
self.num_in_ch = num_in_ch
|
||||
self.num_out_ch = num_out_ch
|
||||
self.num_feat = num_feat
|
||||
self.num_conv = num_conv
|
||||
self.upscale = upscale
|
||||
self.act_type = act_type
|
||||
|
||||
self.body = nn.ModuleList()
|
||||
# the first conv
|
||||
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
|
||||
# the first activation
|
||||
if act_type == 'relu':
|
||||
activation = nn.ReLU(inplace=True)
|
||||
elif act_type == 'prelu':
|
||||
activation = nn.PReLU(num_parameters=num_feat)
|
||||
elif act_type == 'leakyrelu':
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.body.append(activation)
|
||||
|
||||
# the body structure
|
||||
for _ in range(num_conv):
|
||||
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
|
||||
# activation
|
||||
if act_type == 'relu':
|
||||
activation = nn.ReLU(inplace=True)
|
||||
elif act_type == 'prelu':
|
||||
activation = nn.PReLU(num_parameters=num_feat)
|
||||
elif act_type == 'leakyrelu':
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.body.append(activation)
|
||||
|
||||
# the last conv
|
||||
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
|
||||
# upsample
|
||||
self.upsampler = nn.PixelShuffle(upscale)
|
||||
|
||||
def forward(self, x):
|
||||
out = x
|
||||
for i in range(0, len(self.body)):
|
||||
out = self.body[i](out)
|
||||
|
||||
out = self.upsampler(out)
|
||||
# add the nearest upsampled image, so that the network learns the residual
|
||||
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
|
||||
out += base
|
||||
return out
|
||||
|
||||
|
||||
####################
|
||||
# Upsampler
|
||||
####################
|
||||
|
||||
class Upsample(nn.Module):
|
||||
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
|
||||
The input data is assumed to be of the form
|
||||
`minibatch x channels x [optional depth] x [optional height] x width`.
|
||||
"""
|
||||
|
||||
def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
super(Upsample, self).__init__()
|
||||
if isinstance(scale_factor, tuple):
|
||||
self.scale_factor = tuple(float(factor) for factor in scale_factor)
|
||||
else:
|
||||
self.scale_factor = float(scale_factor) if scale_factor else None
|
||||
self.mode = mode
|
||||
self.size = size
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
|
||||
|
||||
def extra_repr(self):
|
||||
if self.scale_factor is not None:
|
||||
info = 'scale_factor=' + str(self.scale_factor)
|
||||
else:
|
||||
info = 'size=' + str(self.size)
|
||||
info += ', mode=' + self.mode
|
||||
return info
|
||||
|
||||
|
||||
def pixel_unshuffle(x, scale):
|
||||
""" Pixel unshuffle.
|
||||
Args:
|
||||
x (Tensor): Input feature with shape (b, c, hh, hw).
|
||||
scale (int): Downsample ratio.
|
||||
Returns:
|
||||
Tensor: the pixel unshuffled feature.
|
||||
"""
|
||||
b, c, hh, hw = x.size()
|
||||
out_channel = c * (scale**2)
|
||||
assert hh % scale == 0 and hw % scale == 0
|
||||
h = hh // scale
|
||||
w = hw // scale
|
||||
x_view = x.view(b, c, h, scale, w, scale)
|
||||
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
|
||||
|
||||
|
||||
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
|
||||
"""
|
||||
Pixel shuffle layer
|
||||
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
|
||||
Neural Network, CVPR17)
|
||||
"""
|
||||
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
|
||||
pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
|
||||
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
||||
|
||||
n = norm(norm_type, out_nc) if norm_type else None
|
||||
a = act(act_type) if act_type else None
|
||||
return sequential(conv, pixel_shuffle, n, a)
|
||||
|
||||
|
||||
def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
|
||||
""" Upconv layer """
|
||||
upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
|
||||
upsample = Upsample(scale_factor=upscale_factor, mode=mode)
|
||||
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
|
||||
pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
|
||||
return sequential(upsample, conv)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
####################
|
||||
# Basic blocks
|
||||
####################
|
||||
|
||||
|
||||
def make_layer(basic_block, num_basic_block, **kwarg):
|
||||
"""Make layers by stacking the same blocks.
|
||||
Args:
|
||||
basic_block (nn.module): nn.module class for basic block. (block)
|
||||
num_basic_block (int): number of blocks. (n_layers)
|
||||
Returns:
|
||||
nn.Sequential: Stacked blocks in nn.Sequential.
|
||||
"""
|
||||
layers = []
|
||||
for _ in range(num_basic_block):
|
||||
layers.append(basic_block(**kwarg))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
|
||||
""" activation helper """
|
||||
act_type = act_type.lower()
|
||||
if act_type == 'relu':
|
||||
layer = nn.ReLU(inplace)
|
||||
elif act_type in ('leakyrelu', 'lrelu'):
|
||||
layer = nn.LeakyReLU(neg_slope, inplace)
|
||||
elif act_type == 'prelu':
|
||||
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
|
||||
elif act_type == 'tanh': # [-1, 1] range output
|
||||
layer = nn.Tanh()
|
||||
elif act_type == 'sigmoid': # [0, 1] range output
|
||||
layer = nn.Sigmoid()
|
||||
else:
|
||||
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
|
||||
return layer
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self, *kwargs):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x, *kwargs):
|
||||
return x
|
||||
|
||||
|
||||
def norm(norm_type, nc):
|
||||
""" Return a normalization layer """
|
||||
norm_type = norm_type.lower()
|
||||
if norm_type == 'batch':
|
||||
layer = nn.BatchNorm2d(nc, affine=True)
|
||||
elif norm_type == 'instance':
|
||||
layer = nn.InstanceNorm2d(nc, affine=False)
|
||||
elif norm_type == 'none':
|
||||
def norm_layer(x): return Identity()
|
||||
else:
|
||||
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
|
||||
return layer
|
||||
|
||||
|
||||
def pad(pad_type, padding):
|
||||
""" padding layer helper """
|
||||
pad_type = pad_type.lower()
|
||||
if padding == 0:
|
||||
return None
|
||||
if pad_type == 'reflect':
|
||||
layer = nn.ReflectionPad2d(padding)
|
||||
elif pad_type == 'replicate':
|
||||
layer = nn.ReplicationPad2d(padding)
|
||||
elif pad_type == 'zero':
|
||||
layer = nn.ZeroPad2d(padding)
|
||||
else:
|
||||
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
|
||||
return layer
|
||||
|
||||
|
||||
def get_valid_padding(kernel_size, dilation):
|
||||
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
|
||||
padding = (kernel_size - 1) // 2
|
||||
return padding
|
||||
|
||||
|
||||
class ShortcutBlock(nn.Module):
|
||||
""" Elementwise sum the output of a submodule to its input """
|
||||
def __init__(self, submodule):
|
||||
super(ShortcutBlock, self).__init__()
|
||||
self.sub = submodule
|
||||
|
||||
def forward(self, x):
|
||||
output = x + self.sub(x)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
|
||||
|
||||
|
||||
def sequential(*args):
|
||||
""" Flatten Sequential. It unwraps nn.Sequential. """
|
||||
if len(args) == 1:
|
||||
if isinstance(args[0], OrderedDict):
|
||||
raise NotImplementedError('sequential does not support OrderedDict input.')
|
||||
return args[0] # No sequential is needed.
|
||||
modules = []
|
||||
for module in args:
|
||||
if isinstance(module, nn.Sequential):
|
||||
for submodule in module.children():
|
||||
modules.append(submodule)
|
||||
elif isinstance(module, nn.Module):
|
||||
modules.append(module)
|
||||
return nn.Sequential(*modules)
|
||||
|
||||
|
||||
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False):
|
||||
""" Conv layer with padding, normalization, activation """
|
||||
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
|
||||
padding = get_valid_padding(kernel_size, dilation)
|
||||
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
||||
padding = padding if pad_type == 'zero' else 0
|
||||
|
||||
if convtype=='PartialConv2D':
|
||||
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
elif convtype=='DeformConv2D':
|
||||
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
elif convtype=='Conv3D':
|
||||
c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
else:
|
||||
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
|
||||
if spectral_norm:
|
||||
c = nn.utils.spectral_norm(c)
|
||||
|
||||
a = act(act_type) if act_type else None
|
||||
if 'CNA' in mode:
|
||||
n = norm(norm_type, out_nc) if norm_type else None
|
||||
return sequential(p, c, n, a)
|
||||
elif mode == 'NAC':
|
||||
if norm_type is None and act_type is not None:
|
||||
a = act(act_type, inplace=False)
|
||||
n = norm(norm_type, in_nc) if norm_type else None
|
||||
return sequential(n, a, p, c)
|
||||
|
||||
@ -1,183 +1,424 @@
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import hashlib
|
||||
import gradio
|
||||
system_bak_path = "webui_log_and_bak"
|
||||
custom_tab_name = "custom fold"
|
||||
faverate_tab_name = "favorites"
|
||||
tabs_list = ["txt2img", "img2img", "extras", faverate_tab_name]
|
||||
def is_valid_date(date):
|
||||
try:
|
||||
time.strptime(date, "%Y%m%d")
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def reduplicative_file_move(src, dst):
|
||||
def same_name_file(basename, path):
|
||||
name, ext = os.path.splitext(basename)
|
||||
f_list = os.listdir(path)
|
||||
max_num = 0
|
||||
for f in f_list:
|
||||
if len(f) <= len(basename):
|
||||
continue
|
||||
f_ext = f[-len(ext):] if len(ext) > 0 else ""
|
||||
if f[:len(name)] == name and f_ext == ext:
|
||||
if f[len(name)] == "(" and f[-len(ext)-1] == ")":
|
||||
number = f[len(name)+1:-len(ext)-1]
|
||||
if number.isdigit():
|
||||
if int(number) > max_num:
|
||||
max_num = int(number)
|
||||
return f"{name}({max_num + 1}){ext}"
|
||||
name = os.path.basename(src)
|
||||
save_name = os.path.join(dst, name)
|
||||
if not os.path.exists(save_name):
|
||||
shutil.move(src, dst)
|
||||
else:
|
||||
name = same_name_file(name, dst)
|
||||
shutil.move(src, os.path.join(dst, name))
|
||||
|
||||
def traverse_all_files(output_dir, image_list, curr_dir=None):
|
||||
curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
|
||||
def traverse_all_files(curr_path, image_list, all_type=False):
|
||||
try:
|
||||
f_list = os.listdir(curr_path)
|
||||
except:
|
||||
if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
|
||||
image_list.append(curr_dir)
|
||||
if all_type or (curr_path[-10:].rfind(".") > 0 and curr_path[-4:] != ".txt" and curr_path[-4:] != ".csv"):
|
||||
image_list.append(curr_path)
|
||||
return image_list
|
||||
for file in f_list:
|
||||
file = file if curr_dir is None else os.path.join(curr_dir, file)
|
||||
file_path = os.path.join(curr_path, file)
|
||||
if file[-4:] == ".txt":
|
||||
file = os.path.join(curr_path, file)
|
||||
if (not all_type) and (file[-4:] == ".txt" or file[-4:] == ".csv"):
|
||||
pass
|
||||
elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
|
||||
elif os.path.isfile(file) and file[-10:].rfind(".") > 0:
|
||||
image_list.append(file)
|
||||
else:
|
||||
image_list = traverse_all_files(output_dir, image_list, file)
|
||||
image_list = traverse_all_files(file, image_list)
|
||||
return image_list
|
||||
|
||||
def auto_sorting(dir_name):
|
||||
bak_path = os.path.join(dir_name, system_bak_path)
|
||||
if not os.path.exists(bak_path):
|
||||
os.mkdir(bak_path)
|
||||
log_file = None
|
||||
files_list = []
|
||||
f_list = os.listdir(dir_name)
|
||||
for file in f_list:
|
||||
if file == system_bak_path:
|
||||
continue
|
||||
file_path = os.path.join(dir_name, file)
|
||||
if not is_valid_date(file):
|
||||
if file[-10:].rfind(".") > 0:
|
||||
files_list.append(file_path)
|
||||
else:
|
||||
files_list = traverse_all_files(file_path, files_list, all_type=True)
|
||||
|
||||
def get_recent_images(dir_name, page_index, step, image_index, tabname):
|
||||
page_index = int(page_index)
|
||||
image_list = []
|
||||
if not os.path.exists(dir_name):
|
||||
pass
|
||||
elif os.path.isdir(dir_name):
|
||||
image_list = traverse_all_files(dir_name, image_list)
|
||||
image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
|
||||
else:
|
||||
print(f'ERROR: "{dir_name}" is not a directory. Check the path in the settings.', file=sys.stderr)
|
||||
num = 48 if tabname != "extras" else 12
|
||||
max_page_index = len(image_list) // num + 1
|
||||
page_index = max_page_index if page_index == -1 else page_index + step
|
||||
page_index = 1 if page_index < 1 else page_index
|
||||
page_index = max_page_index if page_index > max_page_index else page_index
|
||||
idx_frm = (page_index - 1) * num
|
||||
image_list = image_list[idx_frm:idx_frm + num]
|
||||
image_index = int(image_index)
|
||||
if image_index < 0 or image_index > len(image_list) - 1:
|
||||
current_file = None
|
||||
hidden = None
|
||||
else:
|
||||
current_file = image_list[int(image_index)]
|
||||
hidden = os.path.join(dir_name, current_file)
|
||||
return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
|
||||
|
||||
|
||||
def first_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, 1, 0, image_index, tabname)
|
||||
|
||||
|
||||
def end_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, -1, 0, image_index, tabname)
|
||||
|
||||
|
||||
def prev_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, page_index, -1, image_index, tabname)
|
||||
|
||||
|
||||
def next_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, page_index, 1, image_index, tabname)
|
||||
|
||||
|
||||
def page_index_change(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, page_index, 0, image_index, tabname)
|
||||
|
||||
|
||||
def show_image_info(num, image_path, filenames):
|
||||
# print(f"select image {num}")
|
||||
file = filenames[int(num)]
|
||||
return file, num, os.path.join(image_path, file)
|
||||
for file in files_list:
|
||||
date_str = time.strftime("%Y%m%d",time.localtime(os.path.getmtime(file)))
|
||||
file_path = os.path.dirname(file)
|
||||
hash_path = hashlib.md5(file_path.encode()).hexdigest()
|
||||
path = os.path.join(dir_name, date_str, hash_path)
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
if log_file is None:
|
||||
log_file = open(os.path.join(bak_path,"path_mapping.csv"),"a")
|
||||
log_file.write(f"{hash_path},{file_path}\n")
|
||||
reduplicative_file_move(file, path)
|
||||
|
||||
date_list = []
|
||||
f_list = os.listdir(dir_name)
|
||||
for f in f_list:
|
||||
if is_valid_date(f):
|
||||
date_list.append(f)
|
||||
elif f == system_bak_path:
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
reduplicative_file_move(os.path.join(dir_name, f), bak_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
today = time.strftime("%Y%m%d",time.localtime(time.time()))
|
||||
if today not in date_list:
|
||||
date_list.append(today)
|
||||
return sorted(date_list, reverse=True)
|
||||
|
||||
def archive_images(dir_name, date_to):
|
||||
filenames = []
|
||||
batch_size =int(opts.images_history_num_per_page * opts.images_history_pages_num)
|
||||
if batch_size <= 0:
|
||||
batch_size = opts.images_history_num_per_page * 6
|
||||
today = time.strftime("%Y%m%d",time.localtime(time.time()))
|
||||
date_to = today if date_to is None or date_to == "" else date_to
|
||||
date_to_bak = date_to
|
||||
if False: #opts.images_history_reconstruct_directory:
|
||||
date_list = auto_sorting(dir_name)
|
||||
for date in date_list:
|
||||
if date <= date_to:
|
||||
path = os.path.join(dir_name, date)
|
||||
if date == today and not os.path.exists(path):
|
||||
continue
|
||||
filenames = traverse_all_files(path, filenames)
|
||||
if len(filenames) > batch_size:
|
||||
break
|
||||
filenames = sorted(filenames, key=lambda file: -os.path.getmtime(file))
|
||||
else:
|
||||
filenames = traverse_all_files(dir_name, filenames)
|
||||
total_num = len(filenames)
|
||||
tmparray = [(os.path.getmtime(file), file) for file in filenames ]
|
||||
date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400
|
||||
filenames = []
|
||||
date_list = {date_to:None}
|
||||
date = time.strftime("%Y%m%d",time.localtime(time.time()))
|
||||
for t, f in tmparray:
|
||||
date = time.strftime("%Y%m%d",time.localtime(t))
|
||||
date_list[date] = None
|
||||
if t <= date_stamp:
|
||||
filenames.append((t, f ,date))
|
||||
date_list = sorted(list(date_list.keys()), reverse=True)
|
||||
sort_array = sorted(filenames, key=lambda x:-x[0])
|
||||
if len(sort_array) > batch_size:
|
||||
date = sort_array[batch_size][2]
|
||||
filenames = [x[1] for x in sort_array]
|
||||
else:
|
||||
date = date_to if len(sort_array) == 0 else sort_array[-1][2]
|
||||
filenames = [x[1] for x in sort_array]
|
||||
filenames = [x[1] for x in sort_array if x[2]>= date]
|
||||
num = len(filenames)
|
||||
last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000))
|
||||
date = date[:4] + "/" + date[4:6] + "/" + date[6:8]
|
||||
date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8]
|
||||
load_info = "<div style='color:#999' align='center'>"
|
||||
load_info += f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages"
|
||||
load_info += "</div>"
|
||||
_, image_list, _, _, visible_num = get_recent_images(1, 0, filenames)
|
||||
return (
|
||||
date_to,
|
||||
load_info,
|
||||
filenames,
|
||||
1,
|
||||
image_list,
|
||||
"",
|
||||
"",
|
||||
visible_num,
|
||||
last_date_from,
|
||||
gradio.update(visible=total_num > num)
|
||||
)
|
||||
|
||||
def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
|
||||
def delete_image(delete_num, name, filenames, image_index, visible_num):
|
||||
if name == "":
|
||||
return filenames, delete_num
|
||||
else:
|
||||
delete_num = int(delete_num)
|
||||
visible_num = int(visible_num)
|
||||
image_index = int(image_index)
|
||||
index = list(filenames).index(name)
|
||||
i = 0
|
||||
new_file_list = []
|
||||
for name in filenames:
|
||||
if i >= index and i < index + delete_num:
|
||||
path = os.path.join(dir_name, name)
|
||||
if os.path.exists(path):
|
||||
print(f"Delete file {path}")
|
||||
os.remove(path)
|
||||
txt_file = os.path.splitext(path)[0] + ".txt"
|
||||
if os.path.exists(name):
|
||||
if visible_num == image_index:
|
||||
new_file_list.append(name)
|
||||
i += 1
|
||||
continue
|
||||
print(f"Delete file {name}")
|
||||
os.remove(name)
|
||||
visible_num -= 1
|
||||
txt_file = os.path.splitext(name)[0] + ".txt"
|
||||
if os.path.exists(txt_file):
|
||||
os.remove(txt_file)
|
||||
else:
|
||||
print(f"Not exists file {path}")
|
||||
print(f"Not exists file {name}")
|
||||
else:
|
||||
new_file_list.append(name)
|
||||
i += 1
|
||||
return new_file_list, 1
|
||||
return new_file_list, 1, visible_num
|
||||
|
||||
def save_image(file_name):
|
||||
if file_name is not None and os.path.exists(file_name):
|
||||
shutil.copy(file_name, opts.outdir_save)
|
||||
|
||||
def get_recent_images(page_index, step, filenames):
|
||||
page_index = int(page_index)
|
||||
num_of_imgs_per_page = int(opts.images_history_num_per_page)
|
||||
max_page_index = len(filenames) // num_of_imgs_per_page + 1
|
||||
page_index = max_page_index if page_index == -1 else page_index + step
|
||||
page_index = 1 if page_index < 1 else page_index
|
||||
page_index = max_page_index if page_index > max_page_index else page_index
|
||||
idx_frm = (page_index - 1) * num_of_imgs_per_page
|
||||
image_list = filenames[idx_frm:idx_frm + num_of_imgs_per_page]
|
||||
length = len(filenames)
|
||||
visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page
|
||||
visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num
|
||||
return page_index, image_list, "", "", visible_num
|
||||
|
||||
def loac_batch_click(date_to):
|
||||
if date_to is None:
|
||||
return time.strftime("%Y%m%d",time.localtime(time.time())), []
|
||||
else:
|
||||
return None, []
|
||||
def forward_click(last_date_from, date_to_recorder):
|
||||
if len(date_to_recorder) == 0:
|
||||
return None, []
|
||||
if last_date_from == date_to_recorder[-1]:
|
||||
date_to_recorder = date_to_recorder[:-1]
|
||||
if len(date_to_recorder) == 0:
|
||||
return None, []
|
||||
return date_to_recorder[-1], date_to_recorder[:-1]
|
||||
|
||||
def backward_click(last_date_from, date_to_recorder):
|
||||
if last_date_from is None or last_date_from == "":
|
||||
return time.strftime("%Y%m%d",time.localtime(time.time())), []
|
||||
if len(date_to_recorder) == 0 or last_date_from != date_to_recorder[-1]:
|
||||
date_to_recorder.append(last_date_from)
|
||||
return last_date_from, date_to_recorder
|
||||
|
||||
|
||||
def first_page_click(page_index, filenames):
|
||||
return get_recent_images(1, 0, filenames)
|
||||
|
||||
def end_page_click(page_index, filenames):
|
||||
return get_recent_images(-1, 0, filenames)
|
||||
|
||||
def prev_page_click(page_index, filenames):
|
||||
return get_recent_images(page_index, -1, filenames)
|
||||
|
||||
def next_page_click(page_index, filenames):
|
||||
return get_recent_images(page_index, 1, filenames)
|
||||
|
||||
def page_index_change(page_index, filenames):
|
||||
return get_recent_images(page_index, 0, filenames)
|
||||
|
||||
def show_image_info(tabname_box, num, page_index, filenames):
|
||||
file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))]
|
||||
tm = "<div style='color:#999' align='right'>" + time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + "</div>"
|
||||
return file, tm, num, file
|
||||
|
||||
def enable_page_buttons():
|
||||
return gradio.update(visible=True)
|
||||
|
||||
def change_dir(img_dir, date_to):
|
||||
warning = None
|
||||
try:
|
||||
if os.path.exists(img_dir):
|
||||
try:
|
||||
f = os.listdir(img_dir)
|
||||
except:
|
||||
warning = f"'{img_dir} is not a directory"
|
||||
else:
|
||||
warning = "The directory is not exist"
|
||||
except:
|
||||
warning = "The format of the directory is incorrect"
|
||||
if warning is None:
|
||||
today = time.strftime("%Y%m%d",time.localtime(time.time()))
|
||||
return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today, gradio.update(visible=True), gradio.update(visible=True)
|
||||
else:
|
||||
return gradio.update(visible=True), gradio.update(visible=False), warning, date_to, gradio.update(visible=False), gradio.update(visible=False)
|
||||
|
||||
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
|
||||
if opts.outdir_samples != "":
|
||||
dir_name = opts.outdir_samples
|
||||
elif tabname == "txt2img":
|
||||
custom_dir = False
|
||||
if tabname == "txt2img":
|
||||
dir_name = opts.outdir_txt2img_samples
|
||||
elif tabname == "img2img":
|
||||
dir_name = opts.outdir_img2img_samples
|
||||
elif tabname == "extras":
|
||||
dir_name = opts.outdir_extras_samples
|
||||
elif tabname == faverate_tab_name:
|
||||
dir_name = opts.outdir_save
|
||||
else:
|
||||
return
|
||||
with gr.Row():
|
||||
renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
|
||||
first_page = gr.Button('First Page')
|
||||
prev_page = gr.Button('Prev Page')
|
||||
page_index = gr.Number(value=1, label="Page Index")
|
||||
next_page = gr.Button('Next Page')
|
||||
end_page = gr.Button('End Page')
|
||||
with gr.Row(elem_id=tabname + "_images_history"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
|
||||
with gr.Row():
|
||||
delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
|
||||
delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
|
||||
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
img_file_info = gr.Textbox(label="Generate Info", interactive=False)
|
||||
img_file_name = gr.Textbox(label="File Name", interactive=False)
|
||||
with gr.Row():
|
||||
custom_dir = True
|
||||
dir_name = None
|
||||
|
||||
if not custom_dir:
|
||||
d = dir_name.split("/")
|
||||
dir_name = d[0]
|
||||
for p in d[1:]:
|
||||
dir_name = os.path.join(dir_name, p)
|
||||
if not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name)
|
||||
|
||||
with gr.Column() as page_panel:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, visible=not custom_dir) as load_batch_box:
|
||||
load_batch = gr.Button('Load', elem_id=tabname + "_images_history_start", full_width=True)
|
||||
with gr.Column(scale=4):
|
||||
with gr.Row():
|
||||
img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir)
|
||||
with gr.Row():
|
||||
with gr.Column(visible=False, scale=1) as batch_panel:
|
||||
with gr.Row():
|
||||
forward = gr.Button('Prev batch')
|
||||
backward = gr.Button('Next batch')
|
||||
with gr.Column(scale=3):
|
||||
load_info = gr.HTML(visible=not custom_dir)
|
||||
with gr.Row(visible=False) as warning:
|
||||
warning_box = gr.Textbox("Message", interactive=False)
|
||||
|
||||
with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel:
|
||||
with gr.Column(scale=2):
|
||||
with gr.Row(visible=True) as turn_page_buttons:
|
||||
#date_to = gr.Dropdown(label="Date to")
|
||||
first_page = gr.Button('First Page')
|
||||
prev_page = gr.Button('Prev Page')
|
||||
page_index = gr.Number(value=1, label="Page Index")
|
||||
next_page = gr.Button('Next Page')
|
||||
end_page = gr.Button('End Page')
|
||||
|
||||
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=opts.images_history_grid_num)
|
||||
with gr.Row():
|
||||
delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
|
||||
delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
img_file_info = gr.Textbox(label="Generate Info", interactive=False, lines=6)
|
||||
gr.HTML("<hr>")
|
||||
img_file_name = gr.Textbox(value="", label="File Name", interactive=False)
|
||||
img_file_time= gr.HTML()
|
||||
with gr.Row():
|
||||
if tabname != faverate_tab_name:
|
||||
save_btn = gr.Button('Collect')
|
||||
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
|
||||
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
||||
|
||||
|
||||
# hiden items
|
||||
with gr.Row(visible=False):
|
||||
renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page")
|
||||
batch_date_to = gr.Textbox(label="Date to")
|
||||
visible_img_num = gr.Number()
|
||||
date_to_recorder = gr.State([])
|
||||
last_date_from = gr.Textbox()
|
||||
tabname_box = gr.Textbox(tabname)
|
||||
image_index = gr.Textbox(value=-1)
|
||||
set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
|
||||
filenames = gr.State()
|
||||
all_images_list = gr.State()
|
||||
hidden = gr.Image(type="pil")
|
||||
info1 = gr.Textbox()
|
||||
info2 = gr.Textbox()
|
||||
|
||||
img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
|
||||
tabname_box = gr.Textbox(tabname, visible=False)
|
||||
image_index = gr.Textbox(value=-1, visible=False)
|
||||
set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
|
||||
filenames = gr.State()
|
||||
hidden = gr.Image(type="pil", visible=False)
|
||||
info1 = gr.Textbox(visible=False)
|
||||
info2 = gr.Textbox(visible=False)
|
||||
|
||||
# turn pages
|
||||
gallery_inputs = [img_path, page_index, image_index, tabname_box]
|
||||
gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
|
||||
|
||||
first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
# page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
|
||||
img_path.submit(change_dir, inputs=[img_path, batch_date_to], outputs=[warning, main_panel, warning_box, batch_date_to, load_batch_box, load_info])
|
||||
|
||||
#change batch
|
||||
change_date_output = [batch_date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from, batch_panel]
|
||||
|
||||
batch_date_to.change(archive_images, inputs=[img_path, batch_date_to], outputs=change_date_output)
|
||||
batch_date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons])
|
||||
batch_date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
|
||||
|
||||
load_batch.click(loac_batch_click, inputs=[batch_date_to], outputs=[batch_date_to, date_to_recorder])
|
||||
forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder])
|
||||
backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder])
|
||||
|
||||
|
||||
#delete
|
||||
delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num])
|
||||
delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None)
|
||||
if tabname != faverate_tab_name:
|
||||
save_btn.click(save_image, inputs=[img_file_name], outputs=None)
|
||||
|
||||
#turn page
|
||||
gallery_inputs = [page_index, filenames]
|
||||
gallery_outputs = [page_index, history_gallery, img_file_name, img_file_time, visible_img_num]
|
||||
first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
renew_page.click(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
|
||||
first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
|
||||
next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
|
||||
prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
|
||||
end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
|
||||
page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
|
||||
renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
|
||||
|
||||
# other funcitons
|
||||
set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
|
||||
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
|
||||
delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
|
||||
set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, img_file_time, image_index, hidden])
|
||||
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
|
||||
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
|
||||
|
||||
# pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
|
||||
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
|
||||
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
|
||||
|
||||
|
||||
|
||||
def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
|
||||
def create_history_tabs(gr, sys_opts, cmp_ops, run_pnginfo, switch_dict):
|
||||
global opts;
|
||||
opts = sys_opts
|
||||
loads_files_num = int(opts.images_history_num_per_page)
|
||||
num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num)
|
||||
if cmp_ops.browse_all_images:
|
||||
tabs_list.append(custom_tab_name)
|
||||
with gr.Blocks(analytics_enabled=False) as images_history:
|
||||
with gr.Tabs() as tabs:
|
||||
with gr.Tab("txt2img history"):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
|
||||
show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
|
||||
with gr.Tab("img2img history"):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
|
||||
show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
|
||||
with gr.Tab("extras history"):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
|
||||
show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
|
||||
for tab in tabs_list:
|
||||
with gr.Tab(tab):
|
||||
with gr.Blocks(analytics_enabled=False) :
|
||||
show_images_history(gr, opts, tab, run_pnginfo, switch_dict)
|
||||
gradio.Checkbox(opts.images_history_preload, elem_id="images_history_preload", visible=False)
|
||||
gradio.Textbox(",".join(tabs_list), elem_id="images_history_tabnames_list", visible=False)
|
||||
|
||||
return images_history
|
||||
|
||||
@ -0,0 +1,53 @@
|
||||
|
||||
callbacks_model_loaded = []
|
||||
callbacks_ui_tabs = []
|
||||
callbacks_ui_settings = []
|
||||
|
||||
|
||||
def clear_callbacks():
|
||||
callbacks_model_loaded.clear()
|
||||
callbacks_ui_tabs.clear()
|
||||
|
||||
|
||||
def model_loaded_callback(sd_model):
|
||||
for callback in callbacks_model_loaded:
|
||||
callback(sd_model)
|
||||
|
||||
|
||||
def ui_tabs_callback():
|
||||
res = []
|
||||
|
||||
for callback in callbacks_ui_tabs:
|
||||
res += callback() or []
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def ui_settings_callback():
|
||||
for callback in callbacks_ui_settings:
|
||||
callback()
|
||||
|
||||
|
||||
def on_model_loaded(callback):
|
||||
"""register a function to be called when the stable diffusion model is created; the model is
|
||||
passed as an argument"""
|
||||
callbacks_model_loaded.append(callback)
|
||||
|
||||
|
||||
def on_ui_tabs(callback):
|
||||
"""register a function to be called when the UI is creating new tabs.
|
||||
The function must either return a None, which means no new tabs to be added, or a list, where
|
||||
each element is a tuple:
|
||||
(gradio_component, title, elem_id)
|
||||
|
||||
gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
|
||||
title is tab text displayed to user in the UI
|
||||
elem_id is HTML id for the tab
|
||||
"""
|
||||
callbacks_ui_tabs.append(callback)
|
||||
|
||||
|
||||
def on_ui_settings(callback):
|
||||
"""register a function to be called before UI settings are populated; add your settings
|
||||
by using shared.opts.add_option(shared.OptionInfo(...)) """
|
||||
callbacks_ui_settings.append(callback)
|
||||
@ -0,0 +1,331 @@
|
||||
import torch
|
||||
|
||||
from einops import repeat
|
||||
from omegaconf import ListConfig
|
||||
|
||||
import ldm.models.diffusion.ddpm
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
|
||||
|
||||
# =================================================================================================
|
||||
# Monkey patch DDIMSampler methods from RunwayML repo directly.
|
||||
# Adapted from:
|
||||
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
|
||||
# =================================================================================================
|
||||
@torch.no_grad()
|
||||
def sample_ddim(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples, intermediates = self.ddim_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [
|
||||
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
||||
for i in range(len(c[k]))
|
||||
]
|
||||
else:
|
||||
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
|
||||
# =================================================================================================
|
||||
# Monkey patch PLMSSampler methods.
|
||||
# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
|
||||
# Adapted from:
|
||||
# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
|
||||
# =================================================================================================
|
||||
@torch.no_grad()
|
||||
def sample_plms(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for PLMS sampling is {size}')
|
||||
|
||||
samples, intermediates = self.plms_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [
|
||||
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
||||
for i in range(len(c[k]))
|
||||
]
|
||||
else:
|
||||
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
||||
|
||||
# =================================================================================================
|
||||
# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
|
||||
# Adapted from:
|
||||
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
|
||||
# =================================================================================================
|
||||
|
||||
@torch.no_grad()
|
||||
def get_unconditional_conditioning(self, batch_size, null_label=None):
|
||||
if null_label is not None:
|
||||
xc = null_label
|
||||
if isinstance(xc, ListConfig):
|
||||
xc = list(xc)
|
||||
if isinstance(xc, dict) or isinstance(xc, list):
|
||||
c = self.get_learned_conditioning(xc)
|
||||
else:
|
||||
if hasattr(xc, "to"):
|
||||
xc = xc.to(self.device)
|
||||
c = self.get_learned_conditioning(xc)
|
||||
else:
|
||||
# todo: get null label from cond_stage_model
|
||||
raise NotImplementedError()
|
||||
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
|
||||
return c
|
||||
|
||||
|
||||
class LatentInpaintDiffusion(LatentDiffusion):
|
||||
def __init__(
|
||||
self,
|
||||
concat_keys=("mask", "masked_image"),
|
||||
masked_image_key="masked_image",
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.masked_image_key = masked_image_key
|
||||
assert self.masked_image_key in concat_keys
|
||||
self.concat_keys = concat_keys
|
||||
|
||||
|
||||
def should_hijack_inpainting(checkpoint_info):
|
||||
return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
|
||||
|
||||
|
||||
def do_inpainting_hijack():
|
||||
ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
|
||||
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
|
||||
|
||||
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
||||
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
||||
|
||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
||||
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
|
||||
Loading…
Reference in New Issue