|
|
|
|
@ -62,14 +62,12 @@ class RestrictedUnpickler(pickle.Unpickler):
|
|
|
|
|
raise Exception(f"global '{module}/{name}' is forbidden")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
allowed_zip_names = ["archive/data.pkl", "archive/version"]
|
|
|
|
|
allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
|
|
|
|
|
|
|
|
|
|
# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
|
|
|
|
|
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
|
|
|
|
|
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
|
|
|
|
|
|
|
|
|
|
def check_zip_filenames(filename, names):
|
|
|
|
|
for name in names:
|
|
|
|
|
if name in allowed_zip_names:
|
|
|
|
|
continue
|
|
|
|
|
if allowed_zip_names_re.match(name):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
@ -82,8 +80,14 @@ def check_pt(filename, extra_handler):
|
|
|
|
|
# new pytorch format is a zip file
|
|
|
|
|
with zipfile.ZipFile(filename) as z:
|
|
|
|
|
check_zip_filenames(filename, z.namelist())
|
|
|
|
|
|
|
|
|
|
with z.open('archive/data.pkl') as file:
|
|
|
|
|
|
|
|
|
|
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
|
|
|
|
|
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
|
|
|
|
|
if len(data_pkl_filenames) == 0:
|
|
|
|
|
raise Exception(f"data.pkl not found in {filename}")
|
|
|
|
|
if len(data_pkl_filenames) > 1:
|
|
|
|
|
raise Exception(f"Multiple data.pkl found in {filename}")
|
|
|
|
|
with z.open(data_pkl_filenames[0]) as file:
|
|
|
|
|
unpickler = RestrictedUnpickler(file)
|
|
|
|
|
unpickler.extra_handler = extra_handler
|
|
|
|
|
unpickler.load()
|
|
|
|
|
|