Merge branch 'master' of https://github.com/AUTOMATIC1111/stable-diffusion-webui into upstream-master
commit
6fdad291bd
@ -0,0 +1,73 @@
|
||||
import os.path
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from multiprocessing import get_context
|
||||
|
||||
|
||||
def _load_tf_and_return_tags(pil_image, threshold):
|
||||
import deepdanbooru as dd
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
this_folder = os.path.dirname(__file__)
|
||||
model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
|
||||
if not os.path.exists(os.path.join(model_path, 'project.json')):
|
||||
# there is no point importing these every time
|
||||
import zipfile
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
|
||||
model_path)
|
||||
with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
|
||||
zip_ref.extractall(model_path)
|
||||
os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
|
||||
|
||||
tags = dd.project.load_tags_from_project(model_path)
|
||||
model = dd.project.load_model_from_project(
|
||||
model_path, compile_model=True
|
||||
)
|
||||
|
||||
width = model.input_shape[2]
|
||||
height = model.input_shape[1]
|
||||
image = np.array(pil_image)
|
||||
image = tf.image.resize(
|
||||
image,
|
||||
size=(height, width),
|
||||
method=tf.image.ResizeMethod.AREA,
|
||||
preserve_aspect_ratio=True,
|
||||
)
|
||||
image = image.numpy() # EagerTensor to np.array
|
||||
image = dd.image.transform_and_pad_image(image, width, height)
|
||||
image = image / 255.0
|
||||
image_shape = image.shape
|
||||
image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
|
||||
|
||||
y = model.predict(image)[0]
|
||||
|
||||
result_dict = {}
|
||||
|
||||
for i, tag in enumerate(tags):
|
||||
result_dict[tag] = y[i]
|
||||
result_tags_out = []
|
||||
result_tags_print = []
|
||||
for tag in tags:
|
||||
if result_dict[tag] >= threshold:
|
||||
if tag.startswith("rating:"):
|
||||
continue
|
||||
result_tags_out.append(tag)
|
||||
result_tags_print.append(f'{result_dict[tag]} {tag}')
|
||||
|
||||
print('\n'.join(sorted(result_tags_print, reverse=True)))
|
||||
|
||||
return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ')
|
||||
|
||||
|
||||
def subprocess_init_no_cuda():
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||
|
||||
|
||||
def get_deepbooru_tags(pil_image, threshold=0.5):
|
||||
context = get_context('spawn')
|
||||
with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor:
|
||||
f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, )
|
||||
ret = f.result() # will rethrow any exceptions
|
||||
return ret
|
||||
@ -0,0 +1,89 @@
|
||||
# this code is adapted from the script contributed by anon from /h/
|
||||
|
||||
import io
|
||||
import pickle
|
||||
import collections
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import torch
|
||||
import numpy
|
||||
import _codecs
|
||||
import zipfile
|
||||
|
||||
|
||||
def encode(*args):
|
||||
out = _codecs.encode(*args)
|
||||
return out
|
||||
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
def persistent_load(self, saved_id):
|
||||
assert saved_id[0] == 'storage'
|
||||
return torch.storage._TypedStorage()
|
||||
|
||||
def find_class(self, module, name):
|
||||
if module == 'collections' and name == 'OrderedDict':
|
||||
return getattr(collections, name)
|
||||
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
|
||||
return getattr(torch._utils, name)
|
||||
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
|
||||
return getattr(torch, name)
|
||||
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
|
||||
return getattr(torch.nn.modules.container, name)
|
||||
if module == 'numpy.core.multiarray' and name == 'scalar':
|
||||
return numpy.core.multiarray.scalar
|
||||
if module == 'numpy' and name == 'dtype':
|
||||
return numpy.dtype
|
||||
if module == '_codecs' and name == 'encode':
|
||||
return encode
|
||||
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
|
||||
import pytorch_lightning.callbacks
|
||||
return pytorch_lightning.callbacks.model_checkpoint
|
||||
if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
|
||||
import pytorch_lightning.callbacks.model_checkpoint
|
||||
return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
|
||||
if module == "__builtin__" and name == 'set':
|
||||
return set
|
||||
|
||||
# Forbid everything else.
|
||||
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
|
||||
|
||||
|
||||
def check_pt(filename):
|
||||
try:
|
||||
|
||||
# new pytorch format is a zip file
|
||||
with zipfile.ZipFile(filename) as z:
|
||||
with z.open('archive/data.pkl') as file:
|
||||
unpickler = RestrictedUnpickler(file)
|
||||
unpickler.load()
|
||||
|
||||
except zipfile.BadZipfile:
|
||||
|
||||
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
|
||||
with open(filename, "rb") as file:
|
||||
unpickler = RestrictedUnpickler(file)
|
||||
for i in range(5):
|
||||
unpickler.load()
|
||||
|
||||
|
||||
def load(filename, *args, **kwargs):
|
||||
from modules import shared
|
||||
|
||||
try:
|
||||
if not shared.cmd_opts.disable_safe_unpickle:
|
||||
check_pt(filename)
|
||||
|
||||
except Exception:
|
||||
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
|
||||
print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr)
|
||||
return None
|
||||
|
||||
return unsafe_torch_load(filename, *args, **kwargs)
|
||||
|
||||
|
||||
unsafe_torch_load = torch.load
|
||||
torch.load = load
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 526 KiB After Width: | Height: | Size: 329 KiB |
Loading…
Reference in New Issue