@ -13,6 +13,7 @@ from modules.ui import plaintext_to_html
import modules . codeformer_model
import piexif
import piexif . helper
import gradio as gr
cached_images = { }
@ -140,7 +141,7 @@ def run_pnginfo(image):
return ' ' , geninfo , info
def run_modelmerger ( primary_model_name , secondary_model_name , interp_method , interp_amount ):
def run_modelmerger ( primary_model_name , secondary_model_name , interp_method , interp_amount , save_as_half ):
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
def weighted_sum ( theta0 , theta1 , alpha ) :
return ( ( 1 - alpha ) * theta0 ) + ( alpha * theta1 )
@ -156,14 +157,14 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
alpha = 0.5 - math . sin ( math . asin ( 1.0 - 2.0 * alpha ) / 3.0 )
return theta0 + ( ( theta1 - theta0 ) * alpha )
primary_model_ filename = sd_models . checkpoints_list [ primary_model_name ] . filename
secondary_model_ filename = sd_models . checkpoints_list [ secondary_model_name ] . filename
primary_model_ info = sd_models . checkpoints_list [ primary_model_name ]
secondary_model_ info = sd_models . checkpoints_list [ secondary_model_name ]
print ( f " Loading { primary_model_ filename} ... " )
primary_model = torch . load ( primary_model_ filename, map_location = ' cpu ' )
print ( f " Loading { primary_model_ info. filename} ... " )
primary_model = torch . load ( primary_model_ info. filename, map_location = ' cpu ' )
print ( f " Loading { secondary_model_ filename} ... " )
secondary_model = torch . load ( secondary_model_ filename, map_location = ' cpu ' )
print ( f " Loading { secondary_model_ info. filename} ... " )
secondary_model = torch . load ( secondary_model_ info. filename, map_location = ' cpu ' )
theta_0 = primary_model [ ' state_dict ' ]
theta_1 = secondary_model [ ' state_dict ' ]
@ -179,16 +180,22 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
for key in tqdm . tqdm ( theta_0 . keys ( ) ) :
if ' model ' in key and key in theta_1 :
theta_0 [ key ] = theta_func ( theta_0 [ key ] , theta_1 [ key ] , ( float ( 1.0 ) - interp_amount ) ) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
if save_as_half :
theta_0 [ key ] = theta_0 [ key ] . half ( )
for key in theta_1 . keys ( ) :
if ' model ' in key and key not in theta_0 :
theta_0 [ key ] = theta_1 [ key ]
if save_as_half :
theta_0 [ key ] = theta_0 [ key ] . half ( )
filename = primary_model_name + ' _ ' + str ( round ( interp_amount , 2 ) ) + ' - ' + secondary_model_name + ' _ ' + str ( round ( ( float ( 1.0 ) - interp_amount ) , 2 ) ) + ' - ' + interp_method . replace ( " " , " _ " ) + ' -merged.ckpt '
filename = primary_model_ info. model_ name + ' _ ' + str ( round ( interp_amount , 2 ) ) + ' - ' + secondary_model_ info. model_ name + ' _ ' + str ( round ( ( float ( 1.0 ) - interp_amount ) , 2 ) ) + ' - ' + interp_method . replace ( " " , " _ " ) + ' -merged.ckpt '
output_modelname = os . path . join ( shared . cmd_opts . ckpt_dir , filename )
print ( f " Saving to { output_modelname } ... " )
torch . save ( primary_model , output_modelname )
sd_models . list_models ( )
print ( f " Checkpoint saved. " )
return " Checkpoint saved to " + output_modelname
return [ " Checkpoint saved to " + output_modelname ] + [ gr . Dropdown . update ( choices = sd_models . checkpoint_tiles ( ) ) for _ in range ( 3 ) ]