|
|
|
|
@ -10,6 +10,7 @@ import torch
|
|
|
|
|
import numpy
|
|
|
|
|
import _codecs
|
|
|
|
|
import zipfile
|
|
|
|
|
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
|
|
|
|
@ -54,11 +55,27 @@ class RestrictedUnpickler(pickle.Unpickler):
|
|
|
|
|
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
allowed_zip_names = ["archive/data.pkl", "archive/version"]
|
|
|
|
|
allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_zip_filenames(filename, names):
|
|
|
|
|
for name in names:
|
|
|
|
|
if name in allowed_zip_names:
|
|
|
|
|
continue
|
|
|
|
|
if allowed_zip_names_re.match(name):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
raise Exception(f"bad file inside {filename}: {name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_pt(filename):
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
|
|
# new pytorch format is a zip file
|
|
|
|
|
with zipfile.ZipFile(filename) as z:
|
|
|
|
|
check_zip_filenames(filename, z.namelist())
|
|
|
|
|
|
|
|
|
|
with z.open('archive/data.pkl') as file:
|
|
|
|
|
unpickler = RestrictedUnpickler(file)
|
|
|
|
|
unpickler.load()
|
|
|
|
|
|