@ -140,7 +140,7 @@ def run_pnginfo(image):
return ' ' , geninfo , info
def run_modelmerger ( modelname_0 , modelname_1 , interp_method , interp_amount ) :
def run_modelmerger ( primary_ model_ name, secondary_ model_ name, interp_method , interp_amount ) :
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
def weighted_sum ( theta0 , theta1 , alpha ) :
return ( ( 1 - alpha ) * theta0 ) + ( alpha * theta1 )
@ -150,26 +150,26 @@ def run_modelmerger(modelname_0, modelname_1, interp_method, interp_amount):
alpha = alpha * alpha * ( 3 - ( 2 * alpha ) )
return theta0 + ( ( theta1 - theta0 ) * alpha )
if os . path . exists ( modelname_0 ) :
model0 _filename = modelname_0
modelname_0 = os . path . splitext ( os . path . basename ( modelname_0 ) ) [ 0 ]
if os . path . exists ( primary_ model_ name) :
primary_ model_filename = primary_ model_ name
primary_ model_ name = os . path . splitext ( os . path . basename ( primary_ model_ name) ) [ 0 ]
else :
model0 _filename = ' models/ ' + modelname_0 + ' .ckpt '
primary_ model_filename = ' models/ ' + primary_ model_ name + ' .ckpt '
if os . path . exists ( modelname_1 ) :
model1 _filename = modelname_1
modelname_1 = os . path . splitext ( os . path . basename ( modelname_1 ) ) [ 0 ]
if os . path . exists ( secondary_ model_ name) :
secondary_ model_filename = secondary_ model_ name
secondary_ model_ name = os . path . splitext ( os . path . basename ( secondary_ model_ name) ) [ 0 ]
else :
model1 _filename = ' models/ ' + modelname_1 + ' .ckpt '
secondary_ model_filename = ' models/ ' + secondary_ model_ name + ' .ckpt '
print ( f " Loading { model0 _filename} ... " )
model_0 = torch . load ( model0 _filename, map_location = ' cpu ' )
print ( f " Loading { primary_ model_filename} ... " )
primary_ model = torch . load ( primary_ model_filename, map_location = ' cpu ' )
print ( f " Loading { model1 _filename} ... " )
model_1 = torch . load ( model1 _filename, map_location = ' cpu ' )
theta_0 = model_0 [ ' state_dict ' ]
theta_1 = model_1 [ ' state_dict ' ]
print ( f " Loading { secondary_ model_filename} ... " )
secondary_ model = torch . load ( secondary_ model_filename, map_location = ' cpu ' )
theta_0 = primary_ model[ ' state_dict ' ]
theta_1 = secondary_ model[ ' state_dict ' ]
theta_funcs = {
" Weighted Sum " : weighted_sum ,
@ -180,15 +180,15 @@ def run_modelmerger(modelname_0, modelname_1, interp_method, interp_amount):
print ( f " Merging... " )
for key in tqdm . tqdm ( theta_0 . keys ( ) ) :
if ' model ' in key and key in theta_1 :
theta_0 [ key ] = theta_func ( theta_0 [ key ] , theta_1 [ key ] , interp_amount )
theta_0 [ key ] = theta_func ( theta_0 [ key ] , theta_1 [ key ] , ( float ( 1.0 ) - interp_amount ) ) # Pass (1.0 - interp_amount) so the reversed amount matches the desired mix ratio in the merged checkpoint
for key in theta_1 . keys ( ) :
if ' model ' in key and key not in theta_0 :
theta_0 [ key ] = theta_1 [ key ]
output_modelname = ' models/ ' + modelname_0 + ' - ' + modelname_1 + ' - ' + interp_method . replace ( " " , " _ " ) + ' - ' + str ( interp_amount ) + ' -merged.ckpt '
output_modelname = ' models/ ' + primary_model_name + ' _ ' + str ( round ( interp_amount , 2 ) ) + ' - ' + secondary_ model_ name + ' _' + str ( round ( ( float ( 1.0) - interp_amount ) , 2 ) ) + ' - ' + interp_method . replace ( " " , " _ " ) + ' -merged.ckpt '
print ( f " Saving to { output_modelname } ... " )
torch . save ( model_0 , output_modelname )
torch . save ( primary_ model, output_modelname )
print ( f " Checkpoint saved. " )
return " Checkpoint saved to " + output_modelname
return " Checkpoint saved to " + output_modelname