|
|
|
|
@@ -189,8 +189,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
|
|
|
|
embedding = hijack.embedding_db.word_embeddings[embedding_name]
|
|
|
|
|
embedding.vec.requires_grad = True
|
|
|
|
|
|
|
|
|
|
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
|
|
|
|
|
|
|
|
|
|
losses = torch.zeros((32,))
|
|
|
|
|
|
|
|
|
|
last_saved_file = "<none>"
|
|
|
|
|
@@ -200,12 +198,27 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
|
|
|
|
if ititial_step > steps:
|
|
|
|
|
return embedding, filename
|
|
|
|
|
|
|
|
|
|
tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
|
|
|
|
|
epoch_len = (tr_img_len * num_repeats) + tr_img_len
|
|
|
|
|
|
|
|
|
|
scheduleIter = iter(LearnSchedule(learn_rate, steps, ititial_step))
|
|
|
|
|
(learn_rate, end_step) = next(scheduleIter)
|
|
|
|
|
print(f'Training at rate of {learn_rate} until step {end_step}')
|
|
|
|
|
|
|
|
|
|
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
|
|
|
|
|
|
|
|
|
|
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
|
|
|
|
for i, (x, text, _) in pbar:
|
|
|
|
|
embedding.step = i + ititial_step
|
|
|
|
|
|
|
|
|
|
if embedding.step > steps:
|
|
|
|
|
break
|
|
|
|
|
if embedding.step > end_step:
|
|
|
|
|
try:
|
|
|
|
|
(learn_rate, end_step) = next(scheduleIter)
|
|
|
|
|
except:
|
|
|
|
|
break
|
|
|
|
|
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
|
|
|
|
|
for pg in optimizer.param_groups:
|
|
|
|
|
pg['lr'] = learn_rate
|
|
|
|
|
|
|
|
|
|
if shared.state.interrupted:
|
|
|
|
|
break
|
|
|
|
|
@@ -276,3 +289,36 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|
|
|
|
|
|
|
|
|
return embedding, filename
|
|
|
|
|
|
|
|
|
|
class LearnSchedule:
    """Iterator over ``(learn_rate, end_step)`` pairs parsed from a schedule string.

    ``learn_rate`` is a comma-separated list of entries, each one of:

    * ``"rate"``        — use this rate until ``max_steps``; parsing stops here.
    * ``"rate:step"``   — use this rate until ``step`` (clamped to ``max_steps``).
    * ``"rate:-1"``     — use this rate until ``max_steps``; parsing stops here.

    Entries whose end step is at or before ``cur_step`` are skipped, so a
    resumed training run starts at the correct point in the schedule.
    Iterating yields ``(float rate, int end_step)`` tuples in order and
    raises ``StopIteration`` when the schedule is exhausted.
    """

    def __init__(self, learn_rate, max_steps, cur_step=0):
        pairs = learn_rate.split(',')
        self.rates = []
        self.it = 0
        self.maxit = 0
        # Plain iteration: the enumerate index in the original was never used.
        for pair in pairs:
            tmp = pair.split(':')
            if len(tmp) == 2:
                step = int(tmp[1])
                if step > cur_step:
                    # Clamp this entry's end step to max_steps; if it reaches
                    # past max_steps, no later entry can ever apply.
                    self.rates.append((float(tmp[0]), min(step, max_steps)))
                    self.maxit += 1
                    if step > max_steps:
                        return
                elif step == -1:
                    # "rate:-1" means train at this rate until the end.
                    self.rates.append((float(tmp[0]), max_steps))
                    self.maxit += 1
                    return
            else:
                # Bare rate with no step: applies until max_steps.
                self.rates.append((float(tmp[0]), max_steps))
                self.maxit += 1
                return

    def __iter__(self):
        return self

    def __next__(self):
        if self.it < self.maxit:
            self.it += 1
            return self.rates[self.it - 1]
        else:
            raise StopIteration
|
|
|
|
|
|