@@ -1,28 +1,32 @@
+import csv
 import datetime
 import glob
 import html
 import os
 import sys
 import traceback
-import tqdm
-import csv
+import modules.textual_inversion.dataset
 import torch
-from ldm.util import default
-from modules import devices, shared, processing, sd_models
-import torch
-from torch import einsum
+import tqdm
 from einops import rearrange, repeat
-import modules.textual_inversion.dataset
+from ldm.util import default
+from modules import devices, processing, sd_models, shared
 from modules.textual_inversion import textual_inversion
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from torch import einsum
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
-    def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
+    activation_dict = {
+        "relu": torch.nn.ReLU,
+        "leakyrelu": torch.nn.LeakyReLU,
+        "elu": torch.nn.ELU,
+        "swish": torch.nn.Hardswish,
+    }
+    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
         super().__init__()
         assert layer_structure is not None, "layer_structure must not be None"
@@ -31,20 +35,26 @@ class HypernetworkModule(torch.nn.Module):
         linears = []
         for i in range(len(layer_structure) - 1):
+            # Add a fully-connected layer
             linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
-            if activation_func == "relu":
-                linears.append(torch.nn.ReLU())
-            elif activation_func == "leakyrelu":
-                linears.append(torch.nn.LeakyReLU())
-            elif activation_func == 'linear' or activation_func is None:
+            # Add an activation func
+            if activation_func == "linear" or activation_func is None:
                 pass
+            elif activation_func in self.activation_dict:
+                linears.append(self.activation_dict[activation_func]())
             else:
                 raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+            # Add layer normalization
             if add_layer_norm:
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+            # Add dropout except last layer
+            if use_dropout and i < len(layer_structure) - 3:
+                linears.append(torch.nn.Dropout(p=0.3))
         self.linear = torch.nn.Sequential(*linears)
         if state_dict is not None:
@@ -93,7 +103,7 @@ class Hypernetwork:
     filename = None
     name = None
-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
         self.filename = None
         self.name = name
         self.layers = {}
@@ -101,13 +111,14 @@ class Hypernetwork:
         self.sd_checkpoint = None
         self.sd_checkpoint_name = None
         self.layer_structure = layer_structure
-        self.add_layer_norm = add_layer_norm
         self.activation_func = activation_func
+        self.add_layer_norm = add_layer_norm
+        self.use_dropout = use_dropout
         for size in enable_sizes or []:
             self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
-                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
             )
     def weights(self):
@@ -129,8 +140,9 @@ class Hypernetwork:
         state_dict['step'] = self.step
         state_dict['name'] = self.name
         state_dict['layer_structure'] = self.layer_structure
-        state_dict['is_layer_norm'] = self.add_layer_norm
         state_dict['activation_func'] = self.activation_func
+        state_dict['is_layer_norm'] = self.add_layer_norm
+        state_dict['use_dropout'] = self.use_dropout
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -144,14 +156,15 @@ class Hypernetwork:
         state_dict = torch.load(filename, map_location='cpu')
         self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
-        self.add_layer_norm = state_dict.get('is_layer_norm', False)
         self.activation_func = state_dict.get('activation_func', None)
+        self.add_layer_norm = state_dict.get('is_layer_norm', False)
+        self.use_dropout = state_dict.get('use_dropout', False)
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
                 )
         self.name = state_dict.get('name', self.name)
@@ -308,6 +321,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         return hypernetwork, filename
     scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+    # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
     steps_without_grad = 0
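
For reference, below is a small standalone sketch (not part of the patch) of what the new HypernetworkModule construction loop produces. The helper name build_layers is hypothetical and error handling is simplified (an unknown activation name raises KeyError instead of RuntimeError); it only assumes PyTorch and mirrors the activation_dict lookup, the optional layer norm, and the `i < len(layer_structure) - 3` dropout rule from the hunk above.

import torch

# Mirrors HypernetworkModule.__init__ from the patch above (hypothetical helper name).
activation_dict = {
    "relu": torch.nn.ReLU,
    "leakyrelu": torch.nn.LeakyReLU,
    "elu": torch.nn.ELU,
    "swish": torch.nn.Hardswish,
}

def build_layers(dim, layer_structure, activation_func=None, add_layer_norm=False, use_dropout=False):
    linears = []
    for i in range(len(layer_structure) - 1):
        # Fully-connected layer between consecutive width multipliers
        linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])))
        # Optional activation looked up in the shared dict ("linear"/None means no activation)
        if activation_func not in (None, "linear"):
            linears.append(activation_dict[activation_func]())
        # Optional layer normalization
        if add_layer_norm:
            linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1])))
        # Dropout on every block except the last two (i < len(layer_structure) - 3)
        if use_dropout and i < len(layer_structure) - 3:
            linears.append(torch.nn.Dropout(p=0.3))
    return torch.nn.Sequential(*linears)

# With the default structure [1, 2, 1] the dropout condition is never true, so this
# prints Linear(320, 640) -> ReLU -> Linear(640, 320) with no Dropout module.
print(build_layers(320, [1, 2, 1], activation_func="relu", use_dropout=True))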