|
|
|
|
@ -17,7 +17,7 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DatasetEntry:
|
|
|
|
|
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
|
|
|
|
|
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, img_shape=None):
|
|
|
|
|
self.filename = filename
|
|
|
|
|
self.filename_text = filename_text
|
|
|
|
|
self.latent_dist = latent_dist
|
|
|
|
|
@ -25,16 +25,15 @@ class DatasetEntry:
|
|
|
|
|
self.cond = cond
|
|
|
|
|
self.cond_text = cond_text
|
|
|
|
|
self.pixel_values = pixel_values
|
|
|
|
|
self.img_shape = img_shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PersonalizedBase(Dataset):
|
|
|
|
|
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
|
|
|
|
|
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False):
|
|
|
|
|
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
|
|
|
|
|
|
|
|
|
|
self.placeholder_token = placeholder_token
|
|
|
|
|
|
|
|
|
|
self.width = width
|
|
|
|
|
self.height = height
|
|
|
|
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
|
|
|
|
|
|
|
|
|
self.dataset = []
|
|
|
|
|
@ -47,6 +46,8 @@ class PersonalizedBase(Dataset):
|
|
|
|
|
assert data_root, 'dataset directory not specified'
|
|
|
|
|
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
|
|
|
|
assert os.listdir(data_root), "Dataset directory is empty"
|
|
|
|
|
if varsize:
|
|
|
|
|
assert batch_size == 1, 'variable img size must have batch size 1'
|
|
|
|
|
|
|
|
|
|
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
|
|
|
|
|
|
|
|
|
@ -59,7 +60,9 @@ class PersonalizedBase(Dataset):
|
|
|
|
|
if shared.state.interrupted:
|
|
|
|
|
raise Exception("interrupted")
|
|
|
|
|
try:
|
|
|
|
|
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
|
|
|
|
|
image = Image.open(path).convert('RGB')
|
|
|
|
|
if not varsize:
|
|
|
|
|
image = image.resize((width, height), PIL.Image.BICUBIC)
|
|
|
|
|
except Exception:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
@ -88,14 +91,14 @@ class PersonalizedBase(Dataset):
|
|
|
|
|
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
|
|
|
|
|
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
|
|
|
|
latent_sampling_method = "once"
|
|
|
|
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
|
|
|
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size)
|
|
|
|
|
elif latent_sampling_method == "deterministic":
|
|
|
|
|
# Works only for DiagonalGaussianDistribution
|
|
|
|
|
latent_dist.std = 0
|
|
|
|
|
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
|
|
|
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
|
|
|
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size)
|
|
|
|
|
elif latent_sampling_method == "random":
|
|
|
|
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
|
|
|
|
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, img_shape=image.size)
|
|
|
|
|
|
|
|
|
|
if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
|
|
|
|
entry.cond_text = self.create_text(filename_text)
|
|
|
|
|
@ -151,6 +154,7 @@ class BatchLoader:
|
|
|
|
|
self.cond_text = [entry.cond_text for entry in data]
|
|
|
|
|
self.cond = [entry.cond for entry in data]
|
|
|
|
|
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
|
|
|
|
|
self.img_shape = [entry.img_shape for entry in data]
|
|
|
|
|
#self.emb_index = [entry.emb_index for entry in data]
|
|
|
|
|
#print(self.latent_sample.device)
|
|
|
|
|
|
|
|
|
|
|