do not replace entire unet for the resolution hack
parent
2641d1b83b
commit
7dbfd8a7d8
@ -0,0 +1,30 @@
|
||||
import torch
|
||||
|
||||
|
||||
class TorchHijackForUnet:
|
||||
"""
|
||||
This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
|
||||
this makes it possible to create pictures with dimensions that are muliples of 8 rather than 64
|
||||
"""
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'cat':
|
||||
return self.cat
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
||||
|
||||
def cat(self, tensors, *args, **kwargs):
|
||||
if len(tensors) == 2:
|
||||
a, b = tensors
|
||||
if a.shape[-2:] != b.shape[-2:]:
|
||||
a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
|
||||
|
||||
tensors = (a, b)
|
||||
|
||||
return torch.cat(tensors, *args, **kwargs)
|
||||
|
||||
|
||||
th = TorchHijackForUnet()
|
||||
Loading…
Reference in New Issue