|
|
|
|
@ -10,6 +10,7 @@ from tqdm import tqdm
|
|
|
|
|
from modules import modelloader
|
|
|
|
|
from modules.shared import cmd_opts, opts, device
|
|
|
|
|
from modules.swinir_model_arch import SwinIR as net
|
|
|
|
|
from modules.swinir_model_arch_v2 import Swin2SR as net2
|
|
|
|
|
from modules.upscaler import Upscaler, UpscalerData
|
|
|
|
|
|
|
|
|
|
precision_scope = (
|
|
|
|
|
@ -57,22 +58,42 @@ class UpscalerSwinIR(Upscaler):
|
|
|
|
|
filename = path
|
|
|
|
|
if filename is None or not os.path.exists(filename):
|
|
|
|
|
return None
|
|
|
|
|
model = net(
|
|
|
|
|
if filename.endswith(".v2.pth"):
|
|
|
|
|
model = net2(
|
|
|
|
|
upscale=scale,
|
|
|
|
|
in_chans=3,
|
|
|
|
|
img_size=64,
|
|
|
|
|
window_size=8,
|
|
|
|
|
img_range=1.0,
|
|
|
|
|
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
|
|
|
|
embed_dim=240,
|
|
|
|
|
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
|
|
|
|
depths=[6, 6, 6, 6, 6, 6],
|
|
|
|
|
embed_dim=180,
|
|
|
|
|
num_heads=[6, 6, 6, 6, 6, 6],
|
|
|
|
|
mlp_ratio=2,
|
|
|
|
|
upsampler="nearest+conv",
|
|
|
|
|
resi_connection="3conv",
|
|
|
|
|
)
|
|
|
|
|
resi_connection="1conv",
|
|
|
|
|
)
|
|
|
|
|
params = None
|
|
|
|
|
else:
|
|
|
|
|
model = net(
|
|
|
|
|
upscale=scale,
|
|
|
|
|
in_chans=3,
|
|
|
|
|
img_size=64,
|
|
|
|
|
window_size=8,
|
|
|
|
|
img_range=1.0,
|
|
|
|
|
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
|
|
|
|
embed_dim=240,
|
|
|
|
|
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
|
|
|
|
mlp_ratio=2,
|
|
|
|
|
upsampler="nearest+conv",
|
|
|
|
|
resi_connection="3conv",
|
|
|
|
|
)
|
|
|
|
|
params = "params_ema"
|
|
|
|
|
|
|
|
|
|
pretrained_model = torch.load(filename)
|
|
|
|
|
model.load_state_dict(pretrained_model["params_ema"], strict=True)
|
|
|
|
|
if params is not None:
|
|
|
|
|
model.load_state_dict(pretrained_model[params], strict=True)
|
|
|
|
|
else:
|
|
|
|
|
model.load_state_dict(pretrained_model, strict=True)
|
|
|
|
|
if not cmd_opts.no_half:
|
|
|
|
|
model = model.half()
|
|
|
|
|
return model
|
|
|
|
|
|