@ -7,6 +7,7 @@ import shlex
import platform
import argparse
import json
import detection
dir_repos = " repositories "
dir_extensions = " extensions "
@ -15,6 +16,12 @@ git = os.environ.get('GIT', "git")
index_url = os . environ . get ( ' INDEX_URL ' , " " )
stored_commit_hash = None
# Get the GPU vendor and the operating system
gpu = detection . check_gpu ( )
if os . name == " posix " :
os_name = platform . uname ( ) . system
else :
os_name = os . name
def commit_hash ( ) :
global stored_commit_hash
@ -173,7 +180,11 @@ def run_extensions_installers(settings_file):
def prepare_environment ( ) :
torch_command = os . environ . get ( ' TORCH_COMMAND ' , " pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 " )
if gpu == " AMD " and os_name != " nt " :
torch_command = os . environ . get ( ' TORCH_COMMAND ' , " pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2 " )
else :
torch_command = os . environ . get ( ' TORCH_COMMAND ' , " pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 " )
requirements_file = os . environ . get ( ' REQS_FILE ' , " requirements_versions.txt " )
commandline_args = os . environ . get ( ' COMMANDLINE_ARGS ' , " " )
@ -295,6 +306,8 @@ def tests(test_dir):
def start ( ) :
print ( f " Operating System: { os_name } " )
print ( f " GPU: { gpu } " )
print ( f " Launching { ' API server ' if ' --nowebui ' in sys . argv else ' Web UI ' } with arguments: { ' ' . join ( sys . argv [ 1 : ] ) } " )
import webui
if ' --nowebui ' in sys . argv :