@@ -5,6 +5,7 @@ import html
 import os
 import sys
 import traceback
+import inspect
 
 import modules.textual_inversion.dataset
 import torch
@@ -15,10 +16,12 @@ from modules import devices, processing, sd_models, shared
 from modules.textual_inversion import textual_inversion
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from torch import einsum
+from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
 
 from collections import defaultdict, deque
 from statistics import stdev, mean
 
+
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
     activation_dict = {
@@ -26,9 +29,12 @@ class HypernetworkModule(torch.nn.Module):
         "leakyrelu": torch.nn.LeakyReLU,
         "elu": torch.nn.ELU,
         "swish": torch.nn.Hardswish,
+        "tanh": torch.nn.Tanh,
+        "sigmoid": torch.nn.Sigmoid,
     }
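+    # also register every activation class exposed by torch.nn.modules.activation, keyed by its class name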
+    activation_dict.update({cls_name: cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
-    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
@@ -65,9 +71,24 @@ class HypernetworkModule(torch.nn.Module):
         else:
             for layer in self.linear:
                 if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
-                    layer.weight.data.normal_(mean=0.0, std=0.01)
-                    layer.bias.data.zero_()
-
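+                    # initialise this layer's weight and bias according to the scheme selected via weight_init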
+                    w, b = layer.weight.data, layer.bias.data
+                    if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
+                        normal_(w, mean=0.0, std=0.01)
+                        normal_(b, mean=0.0, std=0.005)
+                    elif weight_init == 'XavierUniform':
+                        xavier_uniform_(w)
+                        zeros_(b)
+                    elif weight_init == 'XavierNormal':
+                        xavier_normal_(w)
+                        zeros_(b)
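+                    # the Kaiming variants honour the activation: pass 'leaky_relu' when that activation is selected, 'relu' otherwise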
+                    elif weight_init == 'KaimingUniform':
+                        kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+                        zeros_(b)
+                    elif weight_init == 'KaimingNormal':
+                        kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+                        zeros_(b)
+                    else:
+                        raise KeyError(f"Key {weight_init} is not defined as initialization!")
         self.to(devices.device)
 
     def fix_old_state_dict(self, state_dict):
@@ -105,7 +126,7 @@ class Hypernetwork:
     filename = None
     name = None
 
-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
         self.filename = None
         self.name = name
         self.layers = {}
@@ -114,13 +135,14 @@ class Hypernetwork:
         self.sd_checkpoint_name = None
         self.layer_structure = layer_structure
         self.activation_func = activation_func
+        self.weight_init = weight_init
         self.add_layer_norm = add_layer_norm
         self.use_dropout = use_dropout
 
         for size in enable_sizes or []:
             self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
            )
 
     def weights(self):
@@ -144,6 +166,7 @@ class Hypernetwork:
         state_dict['layer_structure'] = self.layer_structure
         state_dict['activation_func'] = self.activation_func
         state_dict['is_layer_norm'] = self.add_layer_norm
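+        # record which initialization scheme was used so it can be reported when the hypernetwork is loaded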
+        state_dict['weight_initialization'] = self.weight_init
         state_dict['use_dropout'] = self.use_dropout
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -158,15 +181,21 @@ class Hypernetwork:
         state_dict = torch.load(filename, map_location='cpu')
 
         self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+        print(self.layer_structure)
         self.activation_func = state_dict.get('activation_func', None)
+        print(f"Activation function is {self.activation_func}")
+        self.weight_init = state_dict.get('weight_initialization', 'Normal')
+        print(f"Weight initialization is {self.weight_init}")
         self.add_layer_norm = state_dict.get('is_layer_norm', False)
+        print(f"Layer norm is set to {self.add_layer_norm}")
         self.use_dropout = state_dict.get('use_dropout', False)
+        print(f"Dropout usage is set to {self.use_dropout}")
 
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
                 )
 
         self.name = state_dict.get('name', self.name)