|
|
|
|
@ -9,6 +9,9 @@ from torchvision import transforms
|
|
|
|
|
import random
|
|
|
|
|
import tqdm
|
|
|
|
|
from modules import devices
|
|
|
|
|
import re
|
|
|
|
|
|
|
|
|
|
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PersonalizedBase(Dataset):
|
|
|
|
|
@ -38,8 +41,8 @@ class PersonalizedBase(Dataset):
|
|
|
|
|
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
|
|
|
|
|
|
|
|
|
|
filename = os.path.basename(path)
|
|
|
|
|
filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
|
|
|
|
|
filename_tokens = [token for token in filename_tokens if token.isalpha()]
|
|
|
|
|
filename_tokens = os.path.splitext(filename)[0]
|
|
|
|
|
filename_tokens = re_tag.findall(filename_tokens)
|
|
|
|
|
|
|
|
|
|
npimage = np.array(image).astype(np.uint8)
|
|
|
|
|
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
|
|
|
|
|