@ -20,6 +20,7 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re . compile ( r " \ .top( \ d+) \ . " )
re_topn = re . compile ( r " \ .top( \ d+) \ . " )
category_types = [ " artists " , " flavors " , " mediums " , " movements " ]
def download_default_clip_interrogate_categories ( content_dir ) :
def download_default_clip_interrogate_categories ( content_dir ) :
print ( " Downloading CLIP categories... " )
print ( " Downloading CLIP categories... " )
@ -27,12 +28,8 @@ def download_default_clip_interrogate_categories(content_dir):
tmpdir = content_dir + " _tmp "
tmpdir = content_dir + " _tmp "
try :
try :
os . makedirs ( tmpdir )
os . makedirs ( tmpdir )
for category_type in category_types :
torch . hub . download_url_to_file ( " https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt " , os . path . join ( tmpdir , " artists.txt " ) )
torch . hub . download_url_to_file ( f " https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/ { category_type } .txt " , os . path . join ( tmpdir , f " { category_type } .txt " ) )
torch . hub . download_url_to_file ( " https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt " , os . path . join ( tmpdir , " flavors.top3.txt " ) )
torch . hub . download_url_to_file ( " https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt " , os . path . join ( tmpdir , " mediums.txt " ) )
torch . hub . download_url_to_file ( " https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt " , os . path . join ( tmpdir , " movements.txt " ) )
os . rename ( tmpdir , content_dir )
os . rename ( tmpdir , content_dir )
except Exception as e :
except Exception as e :
@ -51,11 +48,12 @@ class InterrogateModels:
def __init__ ( self , content_dir ) :
def __init__ ( self , content_dir ) :
self . loaded_categories = None
self . loaded_categories = None
self . selected_categories = [ ]
self . content_dir = content_dir
self . content_dir = content_dir
self . running_on_cpu = devices . device_interrogate == torch . device ( " cpu " )
self . running_on_cpu = devices . device_interrogate == torch . device ( " cpu " )
def categories ( self ) :
def categories ( self ) :
if self . loaded_categories is not None :
if self . loaded_categories is not None and self . selected_categories == shared . opts . interrogate_clip_categories :
return self . loaded_categories
return self . loaded_categories
self . loaded_categories = [ ]
self . loaded_categories = [ ]
@ -64,14 +62,19 @@ class InterrogateModels:
download_default_clip_interrogate_categories ( self . content_dir )
download_default_clip_interrogate_categories ( self . content_dir )
if os . path . exists ( self . content_dir ) :
if os . path . exists ( self . content_dir ) :
for filename in os . listdir ( self . content_dir ) :
self . selected_categories = shared . opts . interrogate_clip_categories
for category_type in category_types :
if ' all ' not in self . selected_categories and category_type not in self . selected_categories :
continue
filename = os . path . join ( self . content_dir , f " { category_type } .txt " )
if not os . path . isfile ( filename ) :
continue
m = re_topn . search ( filename )
m = re_topn . search ( filename )
topn = 1 if m is None else int ( m . group ( 1 ) )
topn = 1 if m is None else int ( m . group ( 1 ) )
with open ( filename , " r " , encoding = " utf8 " ) as file :
with open ( os . path . join ( self . content_dir , filename ) , " r " , encoding = " utf8 " ) as file :
lines = [ x . strip ( ) for x in file . readlines ( ) ]
lines = [ x . strip ( ) for x in file . readlines ( ) ]
self . loaded_categories . append ( Category ( name = filenam e, topn = topn , items = lines ) )
self . loaded_categories . append ( Category ( name = category_typ e, topn = topn , items = lines ) )
return self . loaded_categories
return self . loaded_categories
@ -139,6 +142,8 @@ class InterrogateModels:
def rank ( self , image_features , text_array , top_count = 1 ) :
def rank ( self , image_features , text_array , top_count = 1 ) :
import clip
import clip
devices . torch_gc ( )
if shared . opts . interrogate_clip_dict_limit != 0 :
if shared . opts . interrogate_clip_dict_limit != 0 :
text_array = text_array [ 0 : int ( shared . opts . interrogate_clip_dict_limit ) ]
text_array = text_array [ 0 : int ( shared . opts . interrogate_clip_dict_limit ) ]