Face De-Blurring
Solution for submission 174153
A detailed solution for submission 174153 to the Face De-Blurring challenge: a shallow three-layer CNN is trained on 512x512 images with a combined MS-SSIM, MSE, and L1 loss, then applied to the blurred test images.
Training Loop
In [ ]:
import numpy as np
import os
import matplotlib.pyplot as plt
import glob
import cv2
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import argparse
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import save_image
# from skimage.metrics import structural_similarity
import pytorch_msssim
# from piqa.ssim import ssim
# from piqa.utils.functional import gaussian_kernel
# constructing the argument parser
parser = argparse.ArgumentParser()
parser.add_argument('-e', '--epochs', type=int, default=200,
                    help='number of epochs to train the model for')
args = vars(parser.parse_args())
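Note (my addition): argparse reads options from the command line, so the flag only matters when this cell is exported as a standalone script, e.g. python train.py --epochs 60 (the file name train.py is just an assumption; without the flag the default of 200 epochs is used). When the cell is executed inside a notebook kernel, a hedged workaround is to parse only the known arguments instead:
args = vars(parser.parse_known_args()[0])  # ignores the kernel's own argv entries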
def save_decoded_image(img, name):
    img = img.view(img.size(0), 3, 512, 512)
    save_image(img, name)
# helper functions
image_dir = '../outputs/saved_images_new6'
os.makedirs(image_dir, exist_ok=True)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)
batch_size = 8
# gauss_blur = os.listdir('../input/gaussian_blurred')
gauss_blur = os.listdir('../train/blur')
gauss_blur.sort()
# sharp = os.listdir('../input/sharp')
sharp = os.listdir('../train/original')
sharp.sort()
gauss_blur_val = os.listdir('../val/blur')
gauss_blur_val.sort()
sharp_val = os.listdir('../val/original')
sharp_val.sort()
x_blur = []
for i in range(len(gauss_blur)):
    x_blur.append(gauss_blur[i])
y_sharp = []
for i in range(len(sharp)):
    y_sharp.append(sharp[i])
x_blur_val = []
for i in range(len(gauss_blur_val)):
    x_blur_val.append(gauss_blur_val[i])
y_sharp_val = []
for i in range(len(sharp_val)):
    y_sharp_val.append(sharp_val[i])
x_train = x_blur
y_train = y_sharp
x_val = x_blur_val
y_val = y_sharp_val
print(len(x_train))
print(len(y_train))
print(len(x_val))
print(len(y_val))
# exit
# (x_train, x_val, y_train, y_val) = train_test_split(x_blur, y_sharp, test_size=0.25)
# define transforms
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])
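As a quick sanity check (my addition, assuming ../train/blur contains at least one image): the transform maps the HxWx3 uint8 array returned by cv2.imread to a 3x512x512 float tensor in [0, 1].
sample = cv2.imread(f"../train/blur/{gauss_blur[0]}")  # BGR uint8 ndarray
sample_t = transform(sample)                            # ToPILImage -> Resize -> ToTensor
print(sample_t.shape, sample_t.min().item(), sample_t.max().item())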
class DeblurDataset(Dataset):
    def __init__(self, blur_paths, sharp_paths=None, transforms=None):
        self.X = blur_paths
        self.y = sharp_paths
        self.transforms = transforms

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        # blur_image = cv2.imread(f"../input/gaussian_blurred/{self.X[i]}")
        blur_image = cv2.imread(f"../train/blur/{self.X[i]}")
        if self.transforms:
            blur_image = self.transforms(blur_image)
        if self.y is not None:
            # sharp_image = cv2.imread(f"../input/sharp/{self.y[i]}")
            sharp_image = cv2.imread(f"../train/original/{self.y[i]}")
            sharp_image = self.transforms(sharp_image)
            return (blur_image, sharp_image)
        else:
            return blur_image
class DeblurDatasetVal(Dataset):
    def __init__(self, blur_paths, sharp_paths=None, transforms=None):
        self.X = blur_paths
        self.y = sharp_paths
        self.transforms = transforms

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        # blur_image = cv2.imread(f"../input/gaussian_blurred/{self.X[i]}")
        blur_image = cv2.imread(f"../val/blur/{self.X[i]}")
        if self.transforms:
            blur_image = self.transforms(blur_image)
        if self.y is not None:
            # sharp_image = cv2.imread(f"../input/sharp/{self.y[i]}")
            sharp_image = cv2.imread(f"../val/original/{self.y[i]}")
            sharp_image = self.transforms(sharp_image)
            return (blur_image, sharp_image)
        else:
            return blur_image
train_data = DeblurDataset(x_train, y_train, transform)
val_data = DeblurDatasetVal(x_val, y_val, transform)
trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
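Optional sanity check (my addition, assuming the ../train folders are populated): pull one batch and confirm the loader yields paired blurred/sharp tensors of shape [batch_size, 3, 512, 512].
blur_batch, sharp_batch = next(iter(trainloader))
print(blur_batch.shape, sharp_batch.shape)  # expected: torch.Size([8, 3, 512, 512]) each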
class DeblurCNN(nn.Module):
    def __init__(self):
        super(DeblurCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=2)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.conv3(x)
        return x
model = DeblurCNN().to(device)
print(model)
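The per-layer padding is not symmetric but cancels out overall: conv1 (kernel 9, padding 2) shrinks a 512x512 input to 508x508, conv2 (kernel 1, padding 2) grows it back to 512x512, and conv3 (kernel 5, padding 2) preserves 512x512. A quick check (my addition):
with torch.no_grad():
    dummy = torch.zeros(1, 3, 512, 512, device=device)
    print(model(dummy).shape)  # expected: torch.Size([1, 3, 512, 512])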
# the loss function
# criterion = nn.MSELoss()
# criterion = structural_similarity
# criterion = pytorch_msssim.MSSSIM()
criterion = pytorch_msssim.msssim
criterion2 = nn.MSELoss()
criterion3 = nn.L1Loss()
# criterion4 = nn.L2Loss()
# kernel = gaussian_kernel(11, sigma=1.5),repeat(3,1,1)
# the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=6,
    factor=0.5,
    verbose=True
)
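Both fit() and validate() below optimise the same equal-weighted sum of three terms: (1 - MS-SSIM) + MSE + L1 between the network output and the sharp target. A minimal sketch of that objective as a standalone function (my refactor for readability, not part of the original script):
def combined_loss(outputs, targets):
    msssim_term = 1 - criterion(outputs, targets, normalize="relu")  # structural dissimilarity
    return msssim_term + criterion2(outputs, targets) + criterion3(outputs, targets)  # + MSE + L1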
def fit(model, dataloader, epoch):
    model.train()
    running_loss = 0.0
    for i, data in tqdm(enumerate(dataloader), total=int(len(train_data)/dataloader.batch_size)):
        blur_image = data[0]
        sharp_image = data[1]
        blur_image = blur_image.to(device)
        sharp_image = sharp_image.to(device)
        optimizer.zero_grad()
        outputs = model(blur_image)
        # loss = criterion(outputs, sharp_image)
        loss1 = 1 - criterion(outputs, sharp_image, normalize="relu")
        loss2 = criterion2(outputs, sharp_image)
        loss3 = criterion3(outputs, sharp_image)
        # loss4 = criterion4(outputs, sharp_image)
        loss = loss1 + loss2 + loss3
        # print(pytorch_msssim.msssim(outputs, sharp_image))
        # outputsim = outputs.view(outputs.size(0), 3, 512, 512)
        # sharp_imageim = sharp_image.view(sharp_image.size(0), 3, 512, 512)
        # loss = 1 - criterion(outputs, sharp_image)
        # backpropagation
        loss.backward()
        # update the parameters
        optimizer.step()
        running_loss += loss.item()
    train_loss = running_loss/len(dataloader.dataset)
    print(f"Train Loss: {train_loss:.5f}")
    return train_loss
def validate(model, dataloader, epoch):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=int(len(val_data)/dataloader.batch_size)):
            # print(data[0].shape)
            blur_image = data[0]
            sharp_image = data[1]
            blur_image = blur_image.to(device)
            sharp_image = sharp_image.to(device)
            outputs = model(blur_image)
            # loss = criterion(outputs, sharp_image)
            loss1 = 1 - criterion(outputs, sharp_image, normalize="relu")
            loss2 = criterion2(outputs, sharp_image)
            loss3 = criterion3(outputs, sharp_image)
            # loss4 = criterion4(outputs, sharp_image)
            loss = loss1 + loss2 + loss3
            # outputsim = outputs.view(outputs.size(0), 3, 512, 512)
            # sharp_imageim = sharp_image.view(sharp_image.size(0), 3, 512, 512)
            # loss = 1 - criterion(outputs, sharp_image)
            running_loss += loss.item()
            if epoch == 0 and i == int((len(val_data)/dataloader.batch_size)-1):
                save_decoded_image(sharp_image.cpu().data, name=f"../outputs/saved_images_new6/sharp{epoch}.jpg")
                save_decoded_image(blur_image.cpu().data, name=f"../outputs/saved_images_new6/blur{epoch}.jpg")
            if i == int((len(val_data)/dataloader.batch_size)-1):
                save_decoded_image(outputs.cpu().data, name=f"../outputs/saved_images_new6/val_deblurred{epoch}.jpg")
    val_loss = running_loss/len(dataloader.dataset)
    print(f"Val Loss: {val_loss:.5f}")
    return val_loss
train_loss = []
val_loss = []
start = time.time()
for epoch in range(args['epochs']):
    print(f"Epoch {epoch+1} of {args['epochs']}")
    train_epoch_loss = fit(model, trainloader, epoch)
    val_epoch_loss = validate(model, valloader, epoch)
    train_loss.append(train_epoch_loss)
    val_loss.append(val_epoch_loss)
    scheduler.step(val_epoch_loss)
end = time.time()
print(f"Took {((end-start)/60):.3f} minutes to train")
# loss plots
plt.figure(figsize=(10, 7))
plt.plot(train_loss, color='orange', label='train loss')
plt.plot(val_loss, color='red', label='validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig('../outputs/loss11.png')
plt.show()
# save the model to disk
print('Saving model...')
torch.save(model.state_dict(), '../outputs/model1epoch11.pth')
Inference Loop
In [ ]:
import numpy as np
import os
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import argparse
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import save_image
from PIL import Image
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
def save_decoded_image(img, name):
    img = img.view(img.size(0), 3, 512, 512)
    save_image(img, name)
class DeblurCNN(nn.Module):
    def __init__(self):
        super(DeblurCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=2)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.conv3(x)
        return x
model = DeblurCNN().to(device)
model.load_state_dict(torch.load('../outputs/model1epoch11.pth'))
model.eval()  # switch to inference mode before running the test images
# print("MODEL2")
# define transforms
transform = transforms.Compose([
    # transforms.ToPILImage(),
    # transforms.Resize((512, 512)),
    transforms.ToTensor(),
])
blurs = os.listdir('../test/blur')
blurs.sort()
blurred_images = []
for i in range(len(blurs)):
    blurred_images.append(blurs[i])
    print(blurs[i])
    img = Image.open(f"../test/blur/{blurs[i]}")
    img = transform(img)
    batch = torch.unsqueeze(img, 0).to(device)
    with torch.no_grad():  # no gradients needed at inference time
        output = model(batch)
    save_decoded_image(output.cpu().data, f"../test/original7/{blurs[i]}")
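One caveat (my observation, not from the original write-up): the training dataset loads images with cv2.imread, which returns BGR arrays, while this loop uses PIL, which yields RGB. A minimal sketch that loads a test image with OpenCV instead, so the channel order matches training:
import cv2  # not imported in this inference cell above

bgr = cv2.imread(f"../test/blur/{blurs[0]}")            # BGR uint8 ndarray, as in training
bgr_t = torch.unsqueeze(transform(bgr), 0).to(device)   # ToTensor also accepts ndarrays
with torch.no_grad():
    restored = model(bgr_t)
print(restored.shape)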
# img = Image.open("../train/blur/00a35.jpg")
# # img = Image.open("../outputs/testingpurposes/00a35inf2.jpg")
# img_t = transform(img)
# batch_t = torch.unsqueeze(img_t, 0).to(device)
# model.eval()
# out = model(batch_t)
# print(out.shape)
# save_decoded_image(out.cpu().data, "../outputs/testingpurposes/00a35inf3.jpg")