def compute_fft(args, dataloader):
    """Split each image into low- and high-frequency parts and save both.

    A circular low-pass mask of radius ``(h + w) // 8`` centred on
    ``(h / 2, w / 2)`` is applied to the centred 2-D spectrum of every image;
    its complement is used as the high-pass mask. The magnitude of each
    inverse transform is written to disk as ``'low_' + name`` and
    ``'high_' + name``.

    Args:
        args: Object exposing an integer ``img_size`` attribute; images are
            treated as square ``img_size x img_size``.
        dataloader: Iterable yielding ``(image_name, img, lbl)`` where
            ``img`` is a 4-D tensor whose last two dimensions are
            ``img_size x img_size`` and ``image_name`` holds one file name
            per sample in the batch.

    Returns:
        None. Results are written as image files and progress is printed.
    """
    # device = args.device if args.device == "cpu" else int(args.device)
    device = 0  # Do not modify
    h, w = args.img_size, args.img_size

    # Build the circular mask with one broadcast comparison instead of a
    # per-pixel Python loop. float64 keeps the comparison identical to the
    # plain Python-float arithmetic of the element-wise formulation.
    radius = (h + w) // 8
    ys = torch.arange(h, dtype=torch.float64) - h / 2
    xs = torch.arange(w, dtype=torch.float64) - w / 2
    inside = (xs[None, :] ** 2 + ys[:, None] ** 2) < radius ** 2
    lpf = inside.to(torch.get_default_dtype())
    hpf = 1 - lpf
    hpf, lpf = hpf.to(device), lpf.to(device)

    to_pil = transforms.ToPILImage()

    with torch.no_grad():
        for i, (image_name, img, lbl) in enumerate(dataloader):
            print(f"img: {i}", image_name, 'label:', lbl)
            img = img.to(device)

            fft_img = torch.fft.fftn(img, dim=(2, 3))
            # Move the zero-frequency bin to the centre so it lines up with
            # the centred masks.
            fft_img = torch.roll(fft_img, (h // 2, w // 2), dims=(2, 3))

            # The spectrum is not rolled back before the inverse transform:
            # a circular shift in frequency only multiplies the spatial
            # signal by a unit-modulus phase factor, which abs() discards.
            X_low = torch.abs(torch.fft.ifftn(fft_img * lpf, dim=(2, 3)))
            X_high = torch.abs(torch.fft.ifftn(fft_img * hpf, dim=(2, 3)))

            # Save every sample of the batch (batch size 1 behaves exactly as
            # a squeeze of the batch dimension would).
            for name, low, high in zip(image_name, X_low, X_high):
                # Clamp to [0, 1]: filtering can overshoot that range and
                # out-of-range values are not representable once the float
                # image is converted to 8-bit for PIL.
                to_pil(low.float().clamp(0, 1).cpu()).save('low_' + name)
                to_pil(high.float().clamp(0, 1).cpu()).save('high_' + name)