|
|
|
|
@ -98,12 +98,12 @@ class PersonalizedBase(Dataset):
|
|
|
|
|
def create_text(self, filename_text):
|
|
|
|
|
text = random.choice(self.lines)
|
|
|
|
|
text = text.replace("[name]", self.placeholder_token)
|
|
|
|
|
tags = filename_text.split(',')
|
|
|
|
|
if shared.opt.tag_drop_out != 0:
|
|
|
|
|
tags = [t for t in tags if random.random() > shared.opt.tag_drop_out]
|
|
|
|
|
if shared.opts.shuffle_tags:
|
|
|
|
|
tags = filename_text.split(',')
|
|
|
|
|
random.shuffle(tags)
|
|
|
|
|
text = text.replace("[filewords]", ','.join(tags))
|
|
|
|
|
else:
|
|
|
|
|
text = text.replace("[filewords]", filename_text)
|
|
|
|
|
text = text.replace("[filewords]", ','.join(tags))
|
|
|
|
|
return text
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
|
|