|
|
|
|
@ -52,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
|
|
|
|
|
self.load_state_dict(state_dict)
|
|
|
|
|
else:
|
|
|
|
|
for layer in self.linear:
|
|
|
|
|
if type(layer) == torch.nn.Linear:
|
|
|
|
|
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
|
|
|
|
layer.weight.data.normal_(mean=0.0, std=0.01)
|
|
|
|
|
layer.bias.data.zero_()
|
|
|
|
|
|
|
|
|
|
@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
|
|
|
|
|
def trainables(self):
|
|
|
|
|
layer_structure = []
|
|
|
|
|
for layer in self.linear:
|
|
|
|
|
if type(layer) == torch.nn.Linear:
|
|
|
|
|
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
|
|
|
|
layer_structure += [layer.weight, layer.bias]
|
|
|
|
|
return layer_structure
|
|
|
|
|
|
|
|
|
|
|