|
|
|
|
@ -118,6 +118,12 @@ class PersonalizedBase(Dataset):
|
|
|
|
|
self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
|
|
|
|
self.latent_sampling_method = latent_sampling_method
|
|
|
|
|
|
|
|
|
|
if len(groups) > 1:
|
|
|
|
|
print("Buckets:")
|
|
|
|
|
for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
|
|
|
|
|
print(f" {w}x{h}: {len(ids)}")
|
|
|
|
|
print()
|
|
|
|
|
|
|
|
|
|
def create_text(self, filename_text):
|
|
|
|
|
text = random.choice(self.lines)
|
|
|
|
|
tags = filename_text.split(',')
|
|
|
|
|
@ -140,8 +146,11 @@ class PersonalizedBase(Dataset):
|
|
|
|
|
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
|
|
|
|
|
return entry
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GroupedBatchSampler(Sampler):
|
|
|
|
|
def __init__(self, data_source: PersonalizedBase, batch_size: int):
|
|
|
|
|
super().__init__(data_source)
|
|
|
|
|
|
|
|
|
|
n = len(data_source)
|
|
|
|
|
self.groups = data_source.groups
|
|
|
|
|
self.len = n_batch = n // batch_size
|
|
|
|
|
@ -150,21 +159,28 @@ class GroupedBatchSampler(Sampler):
|
|
|
|
|
self.n_rand_batches = nrb = n_batch - sum(self.base)
|
|
|
|
|
self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
|
|
|
|
|
self.batch_size = batch_size
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
|
return self.len
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
|
b = self.batch_size
|
|
|
|
|
|
|
|
|
|
for g in self.groups:
|
|
|
|
|
shuffle(g)
|
|
|
|
|
|
|
|
|
|
batches = []
|
|
|
|
|
for g in self.groups:
|
|
|
|
|
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
|
|
|
|
|
for _ in range(self.n_rand_batches):
|
|
|
|
|
rand_group = choices(self.groups, self.probs)[0]
|
|
|
|
|
batches.append(choices(rand_group, k=b))
|
|
|
|
|
|
|
|
|
|
shuffle(batches)
|
|
|
|
|
|
|
|
|
|
yield from batches
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PersonalizedDataLoader(DataLoader):
|
|
|
|
|
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
|
|
|
|
|
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
|
|
|
|
|
|