|
|
|
|
@ -7,124 +7,11 @@ import tqdm
|
|
|
|
|
import html
|
|
|
|
|
import datetime
|
|
|
|
|
|
|
|
|
|
from PIL import Image,PngImagePlugin,ImageDraw
|
|
|
|
|
from ..images import captionImageOverlay
|
|
|
|
|
import numpy as np
|
|
|
|
|
import base64
|
|
|
|
|
import json
|
|
|
|
|
import zlib
|
|
|
|
|
from PIL import Image,PngImagePlugin
|
|
|
|
|
|
|
|
|
|
from modules import shared, devices, sd_hijack, processing, sd_models
|
|
|
|
|
import modules.textual_inversion.dataset
|
|
|
|
|
|
|
|
|
|
class EmbeddingEncoder(json.JSONEncoder):
    """JSON encoder that can serialize ``torch.Tensor`` values.

    A tensor is written as a one-key object ``{'TORCHTENSOR': <nested list>}``;
    every other object is handed to the stock ``json.JSONEncoder`` behavior.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for *obj*."""
        if not isinstance(obj, torch.Tensor):
            # Not ours to handle: let the base class raise its usual TypeError.
            return super().default(obj)
        # Move to CPU and drop autograd tracking before converting to lists.
        as_lists = obj.cpu().detach().numpy().tolist()
        return {'TORCHTENSOR': as_lists}
|
|
|
|
|
|
|
|
|
|
class EmbeddingDecoder(json.JSONDecoder):
    """JSON decoder that turns ``{'TORCHTENSOR': ...}`` objects into tensors.

    Installs its own ``object_hook`` on the underlying ``json.JSONDecoder``;
    all other constructor arguments are forwarded unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, d):
        """Convert a decoded JSON object *d*, or return it untouched.

        Any dict containing the key ``'TORCHTENSOR'`` is replaced by a tensor
        built from that key's value; other dicts pass through as-is.
        """
        if 'TORCHTENSOR' not in d:
            return d
        values = np.array(d['TORCHTENSOR'])
        return torch.from_numpy(values)
|
|
|
|
|
|
|
|
|
|
def embeddingToB64(data):
    """Serialize *data* to JSON (tensor-aware) and return it base64-encoded.

    The result is the ``bytes`` object produced by ``base64.b64encode``.
    """
    serialized = json.dumps(data, cls=EmbeddingEncoder).encode()
    return base64.b64encode(serialized)
|
|
|
|
|
|
|
|
|
|
def embeddingFromB64(data):
    """Decode base64 *data* and parse the JSON inside with ``EmbeddingDecoder``.

    Inverse of ``embeddingToB64``: tagged tensor objects come back as tensors.
    """
    raw_json = base64.b64decode(data)
    return json.loads(raw_json, cls=EmbeddingDecoder)
|
|
|
|
|
|
|
|
|
|
def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
    """Endless linear congruential generator.

    Each step computes ``state = (a * state + c) % m`` starting from *seed*
    and yields the new state; the seed itself is never yielded.
    """
    state = seed
    while True:
        state = (state * a + c) % m
        yield state
|
|
|
|
|
|
|
|
|
|
def xorBlock(block):
    """XOR *block* with a deterministic pseudo-random low-nibble mask.

    A fresh ``lcg()`` stream (default parameters) supplies one value per
    element of *block*; only the low 4 bits of each value are used, so the
    high nibble of every element is left unchanged. Because the mask is the
    same for a given shape, applying the function twice to same-shaped data
    restores the original values.

    Parameters:
        block: array-like with a ``shape``; it is cast to ``uint8``.

    Returns:
        ``uint8`` array with the same shape as *block*.
    """
    g = lcg()
    # np.product was removed in NumPy 2.0; np.prod is the supported spelling.
    count = int(np.prod(block.shape))
    # Masking each value before it enters the array gives the same bits as
    # casting a wide integer array to uint8 and then applying & 0x0F.
    mask = np.array([next(g) & 0x0F for _ in range(count)], dtype=np.uint8).reshape(block.shape)
    return np.bitwise_xor(block.astype(np.uint8), mask)
|
|
|
|
|
|
|
|
|
|
def styleBlock(block,sequence):
    """XOR a dotted grayscale pattern into the high nibble of *block*.

    An RGB canvas the size of *block* (width ``block.shape[1]``, height
    ``block.shape[0]``) is filled with 6-pixel ellipses on an 8-pixel grid
    starting at -6 on both axes. Even-indexed rows of dots are shifted 4
    pixels to the right. The gray level of each dot is taken from *sequence*,
    cycling when the dots outnumber its entries. Only the top 4 bits of the
    rendered canvas are XORed into *block*.
    """
    canvas = Image.new('RGB', (block.shape[1], block.shape[0]))
    pen = ImageDraw.Draw(canvas)
    width, height = canvas.size

    dot_index = 0
    for left in range(-6, width, 8):
        for row, top in enumerate(range(-6, height, 8)):
            shift = 4 if row % 2 == 0 else 0
            tone = sequence[dot_index % len(sequence)]
            dot_index += 1
            pen.ellipse((left + shift, top, left + 6 + shift, top + 6), fill=(tone, tone, tone))

    # Keep only the high nibble so the low nibble of block is untouched.
    overlay = np.array(canvas).astype(np.uint8) & 0xF0
    return block ^ overlay
|
|
|
|
|
|
|
|
|
|
def insertImageDataEmbed(image,data):
    """Return a copy of *image* flanked by two pixel blocks that encode *data*.

    *data* is JSON-serialized with ``EmbeddingEncoder``, zlib-compressed at
    level 9 and split into low and high nibbles. Each nibble stream is
    zero-padded, reshaped to ``(image height, -1, 3)``, decorated with
    ``styleBlock`` and scrambled with ``xorBlock``. The result is a black RGB
    canvas holding, left to right: the low-nibble block, a 1-pixel gap,
    *image*, a 1-pixel gap, and the high-nibble block.

    The dot pattern is driven by the first row of the first value in
    ``data['string_to_param']`` (at most 1024 entries), normalized to 0-255.
    """
    channels = 3
    payload = json.dumps(data, cls=EmbeddingEncoder).encode()
    packed = np.frombuffer(zlib.compress(payload, level=9), np.uint8).copy()
    high_nibbles = packed >> 4
    low_nibbles = packed & 0x0F

    height = image.size[1]
    # Round up to the next multiple of height strictly above the data length,
    # then to the next multiple of height * channels strictly above that, so
    # the reshape below is exact. At least one padding element is always added.
    padded_len = low_nibbles.shape[0] + (height - (low_nibbles.shape[0] % height))
    row_stride = height * channels
    padded_len = padded_len + (row_stride - (padded_len % row_stride))

    # In-place resize: each array must be the sole owner of its buffer here.
    low_nibbles.resize(padded_len)
    low_nibbles = low_nibbles.reshape((height, -1, channels))

    high_nibbles.resize(padded_len)
    high_nibbles = high_nibbles.reshape((height, -1, channels))

    first_param = list(data['string_to_param'].values())[0]
    weights = first_param.cpu().detach().numpy().tolist()[0][:1024]
    magnitudes = np.abs(weights)
    weights = (magnitudes / np.max(magnitudes) * 255).astype(np.uint8)

    low_nibbles = styleBlock(low_nibbles, sequence=weights)
    low_nibbles = xorBlock(low_nibbles)

    # The high block uses the same weights in reverse order.
    high_nibbles = styleBlock(high_nibbles, sequence=weights[::-1])
    high_nibbles = xorBlock(high_nibbles)

    left_panel = Image.fromarray(low_nibbles, mode='RGB')
    right_panel = Image.fromarray(high_nibbles, mode='RGB')

    total_width = image.size[0] + left_panel.size[0] + right_panel.size[0] + 2
    background = Image.new('RGB', (total_width, image.size[1]), (0, 0, 0))
    background.paste(left_panel, (0, 0))
    background.paste(image, (left_panel.size[0] + 1, 0))
    background.paste(right_panel, (left_panel.size[0] + 1 + image.size[0] + 1, 0))

    return background
|
|
|
|
|
|
|
|
|
|
def crop_black(img,tol=0):
    """Trim outer rows and columns of *img* that contain no bright pixel.

    A pixel counts as bright when every value along axis 2 exceeds *tol*.
    The returned slice spans from the first to the last row and column that
    hold at least one bright pixel. If no pixel is bright, *img* is returned
    uncropped.
    """
    bright = (img > tol).all(2)
    cols_hit = bright.any(0)
    rows_hit = bright.any(1)

    # argmax on a boolean vector gives the index of the first True; scanning
    # the reversed vector gives the distance of the last True from the end.
    left = cols_hit.argmax()
    right = bright.shape[1] - cols_hit[::-1].argmax()
    top = rows_hit.argmax()
    bottom = bright.shape[0] - rows_hit[::-1].argmax()

    return img[top:bottom, left:right]
|
|
|
|
|
|
|
|
|
|
def extractImageDataEmbed(image):
    """Recover data embedded in the side blocks of *image*.

    The image is converted to RGB, cropped with ``crop_black`` and masked to
    the low 4 bits of every channel. Columns whose masked values sum to zero
    act as separators: everything left of the first one is the low-nibble
    block and everything right of the last one is the high-nibble block.
    Both blocks are unscrambled with ``xorBlock``, recombined into bytes,
    zlib-decompressed and parsed as JSON with ``EmbeddingDecoder``.

    Returns the decoded object, or ``None`` (after printing a message) when
    fewer than two separator columns are found.
    """
    channels = 3
    rgb = image.convert('RGB')
    pixels = np.array(rgb.getdata()).reshape(image.size[1], image.size[0], channels).astype(np.uint8)
    nibbles = crop_black(pixels) & 0x0F

    empty_cols = np.where(np.sum(nibbles, axis=(0, 2)) == 0)[0]
    if empty_cols.shape[0] < 2:
        print('No Image data blocks found.')
        return None

    low_block = nibbles[:, :empty_cols.min(), :].astype(np.uint8)
    high_block = nibbles[:, empty_cols.max() + 1:, :].astype(np.uint8)

    low_block = xorBlock(low_block)
    high_block = xorBlock(high_block)

    # Reassemble each byte from its two nibbles.
    packed = ((high_block << 4) | low_block).flatten().tobytes()

    return json.loads(zlib.decompress(packed), cls=EmbeddingDecoder)
|
|
|
|
|
|
|
|
|
|
class Embedding:
|
|
|
|
|
def __init__(self, vec, name, step=None):
|
|
|
|
|
@ -199,10 +86,10 @@ class EmbeddingDatabase:
|
|
|
|
|
if filename.upper().endswith('.PNG'):
|
|
|
|
|
embed_image = Image.open(path)
|
|
|
|
|
if 'sd-ti-embedding' in embed_image.text:
|
|
|
|
|
data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
|
|
|
|
|
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
|
|
|
|
name = data.get('name',name)
|
|
|
|
|
else:
|
|
|
|
|
data = extractImageDataEmbed(embed_image)
|
|
|
|
|
data = extract_image_data_embed(embed_image)
|
|
|
|
|
name = data.get('name',name)
|
|
|
|
|
else:
|
|
|
|
|
data = torch.load(path, map_location="cpu")
|
|
|
|
|
@ -393,7 +280,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
|
|
|
|
|
|
|
|
|
info = PngImagePlugin.PngInfo()
|
|
|
|
|
data = torch.load(last_saved_file)
|
|
|
|
|
info.add_text("sd-ti-embedding", embeddingToB64(data))
|
|
|
|
|
info.add_text("sd-ti-embedding", embedding_to_b64(data))
|
|
|
|
|
|
|
|
|
|
title = "<{}>".format(data.get('name','???'))
|
|
|
|
|
checkpoint = sd_models.select_checkpoint()
|
|
|
|
|
@ -401,8 +288,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
|
|
|
|
footer_mid = '[{}]'.format(checkpoint.hash)
|
|
|
|
|
footer_right = '{}'.format(embedding.step)
|
|
|
|
|
|
|
|
|
|
captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
|
|
|
|
|
captioned_image = insertImageDataEmbed(captioned_image,data)
|
|
|
|
|
captioned_image = caption_image_overlay(image,title,footer_left,footer_mid,footer_right)
|
|
|
|
|
captioned_image = insert_image_data_embed(captioned_image,data)
|
|
|
|
|
|
|
|
|
|
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
|
|
|
|
|
|
|
|
|
|