@ -2,6 +2,7 @@ import os
import sys
import traceback
from collections import namedtuple
from pathlib import Path
import re
import torch
@ -20,12 +21,16 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re . compile ( r " \ .top( \ d+) \ . " )
category_types = [ " artists " , " flavors " , " mediums " , " movements " ]
def category_types ( ) :
return [ f . stem for f in Path ( shared . interrogator . content_dir ) . glob ( ' *.txt ' ) ]
def download_default_clip_interrogate_categories ( content_dir ) :
print ( " Downloading CLIP categories... " )
tmpdir = content_dir + " _tmp "
category_types = [ " artists " , " flavors " , " mediums " , " movements " ]
try :
os . makedirs ( tmpdir )
for category_type in category_types :
@ -48,33 +53,32 @@ class InterrogateModels:
def __init__ ( self , content_dir ) :
self . loaded_categories = None
self . s elected _categories = [ ]
self . s kip _categories = [ ]
self . content_dir = content_dir
self . running_on_cpu = devices . device_interrogate == torch . device ( " cpu " )
def categories ( self ) :
if self . loaded_categories is not None and self . selected_categories == shared . opts . interrogate_clip_categories :
if not os . path . exists ( self . content_dir ) :
download_default_clip_interrogate_categories ( self . content_dir )
if self . loaded_categories is not None and self . skip_categories == shared . opts . interrogate_clip_skip_categories :
return self . loaded_categories
self . loaded_categories = [ ]
if not os . path . exists ( self . content_dir ) :
download_default_clip_interrogate_categories ( self . content_dir )
if os . path . exists ( self . content_dir ) :
self . selected_categories = shared . opts . interrogate_clip_categories
for category_type in category_types :
if ' all ' not in self . selected_categories and category_type not in self . selected_categories :
continue
filename = os . path . join ( self . content_dir , f " { category_type } .txt " )
if not os . path . isfile ( filename ) :
self . skip_categories = shared . opts . interrogate_clip_skip_categories
category_types = [ ]
for filename in Path ( self . content_dir ) . glob ( ' *.txt ' ) :
category_types . append ( filename . stem )
if filename . stem in self . skip_categories :
continue
m = re_topn . search ( filename )
m = re_topn . search ( filename . stem )
topn = 1 if m is None else int ( m . group ( 1 ) )
with open ( filename , " r " , encoding = " utf8 " ) as file :
lines = [ x . strip ( ) for x in file . readlines ( ) ]
self . loaded_categories . append ( Category ( name = category_type , topn = topn , items = lines ) )
self . loaded_categories . append ( Category ( name = filename. stem , topn = topn , items = lines ) )
return self . loaded_categories