@ -19,9 +19,10 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry :
def __init__ ( self , filename = None , filename_text = None , latent_dist = None , latent_sample = None , cond = None , cond_text = None , pixel_values = None ):
def __init__ ( self , filename = None , filename_text = None , latent_dist = None , latent_sample = None , cond = None , cond_text = None , pixel_values = None , weight = None ):
self . filename = filename
self . filename_text = filename_text
self . weight = weight
self . latent_dist = latent_dist
self . latent_sample = latent_sample
self . cond = cond
@ -30,7 +31,7 @@ class DatasetEntry:
class PersonalizedBase ( Dataset ) :
def __init__ ( self , data_root , width , height , repeats , flip_p = 0.5 , placeholder_token = " * " , model = None , cond_model = None , device = None , template_file = None , include_cond = False , batch_size = 1 , gradient_step = 1 , shuffle_tags = False , tag_drop_out = 0 , latent_sampling_method = ' once ' , varsize = False ):
def __init__ ( self , data_root , width , height , repeats , flip_p = 0.5 , placeholder_token = " * " , model = None , cond_model = None , device = None , template_file = None , include_cond = False , batch_size = 1 , gradient_step = 1 , shuffle_tags = False , tag_drop_out = 0 , latent_sampling_method = ' once ' , varsize = False , use_weight = False ):
re_word = re . compile ( shared . opts . dataset_filename_word_regex ) if len ( shared . opts . dataset_filename_word_regex ) > 0 else None
self . placeholder_token = placeholder_token
@ -56,10 +57,16 @@ class PersonalizedBase(Dataset):
print ( " Preparing dataset... " )
for path in tqdm . tqdm ( self . image_paths ) :
alpha_channel = None
if shared . state . interrupted :
raise Exception ( " interrupted " )
try :
image = Image . open ( path ) . convert ( ' RGB ' )
image = Image . open ( path )
#Currently does not work for single color transparency
#We would need to read image.info['transparency'] for that
if use_weight and ' A ' in image . getbands ( ) :
alpha_channel = image . getchannel ( ' A ' )
image = image . convert ( ' RGB ' )
if not varsize :
image = image . resize ( ( width , height ) , PIL . Image . BICUBIC )
except Exception :
@ -87,17 +94,35 @@ class PersonalizedBase(Dataset):
with devices . autocast ( ) :
latent_dist = model . encode_first_stage ( torchdata . unsqueeze ( dim = 0 ) )
if latent_sampling_method == " once " or ( latent_sampling_method == " deterministic " and not isinstance ( latent_dist , DiagonalGaussianDistribution ) ) :
latent_sample = model . get_first_stage_encoding ( latent_dist ) . squeeze ( ) . to ( devices . cpu )
latent_sampling_method = " once "
entry = DatasetEntry ( filename = path , filename_text = filename_text , latent_sample = latent_sample )
elif latent_sampling_method == " deterministic " :
#Perform latent sampling, even for random sampling.
#We need the sample dimensions for the weights
if latent_sampling_method == " deterministic " :
if isinstance ( latent_dist , DiagonalGaussianDistribution ) :
# Works only for DiagonalGaussianDistribution
latent_dist . std = 0
else :
latent_sampling_method = " once "
latent_sample = model . get_first_stage_encoding ( latent_dist ) . squeeze ( ) . to ( devices . cpu )
entry = DatasetEntry ( filename = path , filename_text = filename_text , latent_sample = latent_sample )
elif latent_sampling_method == " random " :
entry = DatasetEntry ( filename = path , filename_text = filename_text , latent_dist = latent_dist )
# Build the per-element loss weight map for this image (only when use_weight is set).
# The map has the same shape as latent_sample so the maps of a batch can be stacked later.
if use_weight:
    # Read the shape from this image's own latent sample: images without an alpha
    # channel need it too, and it must not be carried over from a previous iteration.
    channels, *latent_size = latent_sample.shape
    if alpha_channel is not None:
        # PIL's resize() takes (width, height) while np.array(image) is indexed
        # [row, column], so the trailing latent dimensions are passed in reverse
        # order to obtain an array whose shape equals latent_size.
        weight_img = alpha_channel.resize(tuple(reversed(latent_size)))
        npweight = np.array(weight_img).astype(np.float32)
        # Repeat for every channel in the latent sample
        weight = torch.tensor(npweight).repeat(channels, 1, 1)
        # Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
        weight -= weight.min()
        weight_mean = weight.mean()
        if weight_mean > 0:
            weight /= weight_mean
        else:
            # A constant alpha channel leaves all zeros after the shift; dividing by
            # the zero mean would produce NaNs, so fall back to uniform weighting.
            weight = torch.ones_like(weight)
    else:
        # If an image does not have an alpha channel, add an all-ones weight map anyway so we can stack it later
        weight = torch.ones([channels] + latent_size)
else:
    weight = None
if latent_sampling_method == " random " :
entry = DatasetEntry ( filename = path , filename_text = filename_text , latent_dist = latent_dist , weight = weight )
else :
entry = DatasetEntry ( filename = path , filename_text = filename_text , latent_sample = latent_sample , weight = weight )
if not ( self . tag_drop_out != 0 or self . shuffle_tags ) :
entry . cond_text = self . create_text ( filename_text )
@ -110,6 +135,7 @@ class PersonalizedBase(Dataset):
del torchdata
del latent_dist
del latent_sample
del weight
self . length = len ( self . dataset )
self . groups = list ( groups . values ( ) )
@ -195,6 +221,10 @@ class BatchLoader:
self . cond_text = [ entry . cond_text for entry in data ]
self . cond = [ entry . cond for entry in data ]
self . latent_sample = torch . stack ( [ entry . latent_sample for entry in data ] ) . squeeze ( 1 )
if all ( entry . weight is not None for entry in data ) :
self . weight = torch . stack ( [ entry . weight for entry in data ] ) . squeeze ( 1 )
else :
self . weight = None
#self.emb_index = [entry.emb_index for entry in data]
#print(self.latent_sample.device)