|
|
|
|
@ -19,37 +19,21 @@ from modules.textual_inversion import textual_inversion
|
|
|
|
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_layer_structure(dim, state_dict):
|
|
|
|
|
i = 0
|
|
|
|
|
res = [1]
|
|
|
|
|
while (key := "linear.{}.weight".format(i)) in state_dict:
|
|
|
|
|
weight = state_dict[key]
|
|
|
|
|
res.append(len(weight) // dim)
|
|
|
|
|
i += 1
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HypernetworkModule(torch.nn.Module):
|
|
|
|
|
multiplier = 1.0
|
|
|
|
|
layer_structure = None
|
|
|
|
|
add_layer_norm = False
|
|
|
|
|
|
|
|
|
|
def __init__(self, dim, state_dict=None):
|
|
|
|
|
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
|
|
|
|
|
super().__init__()
|
|
|
|
|
if (state_dict is None or 'linear.0.weight' not in state_dict) and self.layer_structure is None:
|
|
|
|
|
layer_structure = (1, 2, 1)
|
|
|
|
|
if layer_structure is not None:
|
|
|
|
|
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
|
|
|
|
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
|
|
|
|
else:
|
|
|
|
|
if self.layer_structure is not None:
|
|
|
|
|
assert self.layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
|
|
|
|
assert self.layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
|
|
|
|
layer_structure = self.layer_structure
|
|
|
|
|
else:
|
|
|
|
|
layer_structure = parse_layer_structure(dim, state_dict)
|
|
|
|
|
layer_structure = parse_layer_structure(dim, state_dict)
|
|
|
|
|
|
|
|
|
|
linears = []
|
|
|
|
|
for i in range(len(layer_structure) - 1):
|
|
|
|
|
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
|
|
|
|
if self.add_layer_norm:
|
|
|
|
|
if add_layer_norm:
|
|
|
|
|
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
|
|
|
|
|
|
|
|
|
self.linear = torch.nn.Sequential(*linears)
|
|
|
|
|
@ -77,38 +61,47 @@ class HypernetworkModule(torch.nn.Module):
|
|
|
|
|
return x + self.linear(x) * self.multiplier
|
|
|
|
|
|
|
|
|
|
def trainables(self):
|
|
|
|
|
res = []
|
|
|
|
|
layer_structure = []
|
|
|
|
|
for layer in self.linear:
|
|
|
|
|
res += [layer.weight, layer.bias]
|
|
|
|
|
return res
|
|
|
|
|
layer_structure += [layer.weight, layer.bias]
|
|
|
|
|
return layer_structure
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def apply_strength(value=None):
|
|
|
|
|
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def apply_layer_structure(value=None):
|
|
|
|
|
HypernetworkModule.layer_structure = value if value is not None else shared.opts.sd_hypernetwork_layer_structure
|
|
|
|
|
def parse_layer_structure(dim, state_dict):
|
|
|
|
|
i = 0
|
|
|
|
|
layer_structure = [1]
|
|
|
|
|
|
|
|
|
|
while (key := "linear.{}.weight".format(i)) in state_dict:
|
|
|
|
|
weight = state_dict[key]
|
|
|
|
|
layer_structure.append(len(weight) // dim)
|
|
|
|
|
i += 1
|
|
|
|
|
|
|
|
|
|
def apply_layer_norm(value=None):
|
|
|
|
|
HypernetworkModule.add_layer_norm = value if value is not None else shared.opts.sd_hypernetwork_add_layer_norm
|
|
|
|
|
return layer_structure
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Hypernetwork:
|
|
|
|
|
filename = None
|
|
|
|
|
name = None
|
|
|
|
|
|
|
|
|
|
def __init__(self, name=None, enable_sizes=None):
|
|
|
|
|
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
|
|
|
|
|
self.filename = None
|
|
|
|
|
self.name = name
|
|
|
|
|
self.layers = {}
|
|
|
|
|
self.step = 0
|
|
|
|
|
self.sd_checkpoint = None
|
|
|
|
|
self.sd_checkpoint_name = None
|
|
|
|
|
self.layer_structure = layer_structure
|
|
|
|
|
self.add_layer_norm = add_layer_norm
|
|
|
|
|
|
|
|
|
|
for size in enable_sizes or []:
|
|
|
|
|
self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
|
|
|
|
|
self.layers[size] = (
|
|
|
|
|
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
|
|
|
|
|
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def weights(self):
|
|
|
|
|
res = []
|
|
|
|
|
@ -128,6 +121,8 @@ class Hypernetwork:
|
|
|
|
|
|
|
|
|
|
state_dict['step'] = self.step
|
|
|
|
|
state_dict['name'] = self.name
|
|
|
|
|
state_dict['layer_structure'] = self.layer_structure
|
|
|
|
|
state_dict['is_layer_norm'] = self.add_layer_norm
|
|
|
|
|
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
|
|
|
|
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
|
|
|
|
|
|
|
|
|
@ -142,10 +137,15 @@ class Hypernetwork:
|
|
|
|
|
|
|
|
|
|
for size, sd in state_dict.items():
|
|
|
|
|
if type(size) == int:
|
|
|
|
|
self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
|
|
|
|
|
self.layers[size] = (
|
|
|
|
|
HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]),
|
|
|
|
|
HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.name = state_dict.get('name', self.name)
|
|
|
|
|
self.step = state_dict.get('step', 0)
|
|
|
|
|
self.layer_structure = state_dict.get('layer_structure', None)
|
|
|
|
|
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
|
|
|
|
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
|
|
|
|
|
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
|
|
|
|
|
|
|
|
|
|