Loading

Face De-Blurring

Solution for submission 174153

A detailed solution for submission 174153, submitted to the Face De-Blurring challenge.

eren23

Training Loop

In [ ]:
from cv2 import repeat
import numpy as np
import os
import matplotlib.pyplot as plt
import glob
import cv2
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import argparse
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import save_image
# from skimage.metrics import structural_similarity
import pytorch_msssim
# from piqa.ssim import ssim
# from piqa.utils.functional import gaussian_kernel



# Command-line arguments. parse_known_args() is used instead of parse_args()
# so unrecognised arguments (e.g. the connection-file flag a Jupyter kernel
# passes when this runs as a notebook cell) are ignored instead of making
# argparse exit.
parser = argparse.ArgumentParser()
parser.add_argument('-e', '--epochs', type=int, default=200,
                    help='number of epochs to train the model for')
args = vars(parser.parse_known_args()[0])
def save_decoded_image(img, name, height=512, width=512):
    """Save an image tensor batch to ``name`` via ``torchvision.utils.save_image``.

    Args:
        img: Image tensor. A 4-D ``(N, C, H, W)`` batch is saved with its own
            spatial size. Any other layout (e.g. flattened images) is reshaped
            to ``(img.size(0), 3, height, width)`` first.
        name: Output file path.
        height: Image height used when ``img`` has to be reshaped.
        width: Image width used when ``img`` has to be reshaped.
    """
    if img.dim() != 4:
        # Only non-batched layouts need restoring; re-viewing a 4-D batch to a
        # fixed 512x512 would scramble (or reject) other resolutions.
        img = img.view(img.size(0), 3, height, width)
    save_image(img, name)

# Global run settings: where validation snapshots are written (created up
# front), which device to use, and the batch size for both loaders.
image_dir = '../outputs/saved_images_new6'
os.makedirs(image_dir, exist_ok=True)
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda:0'
print(device)
batch_size = 8

# File names for the train / validation splits. Every listing is sorted so
# that blurred and sharp files line up by position: the datasets pair the
# i-th blurred name with the i-th sharp name.
gauss_blur = sorted(os.listdir('../train/blur'))
sharp = sorted(os.listdir('../train/original'))
gauss_blur_val = sorted(os.listdir('../val/blur'))
sharp_val = sorted(os.listdir('../val/original'))

# Independent copies of the sorted listings.
x_blur = list(gauss_blur)
y_sharp = list(sharp)
x_blur_val = list(gauss_blur_val)
y_sharp_val = list(sharp_val)

x_train = x_blur
y_train = y_sharp
x_val = x_blur_val
y_val = y_sharp_val

print(len(x_train))
print(len(y_train))
print(len(x_val))
print(len(y_val))

# exit

# (x_train, x_val, y_train, y_val) = train_test_split(x_blur, y_sharp, test_size=0.25)


# Training-time preprocessing: array -> PIL image -> fixed 512x512 resize ->
# tensor. Every training and validation image goes through this pipeline.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ]
)

class DeblurDataset(Dataset):
    """Paired (blurred, sharp) image dataset backed by two folders.

    Images are read with OpenCV, which returns BGR channel order. They are
    converted to RGB so the tensors match PIL-based loading (as used by the
    inference script) and so saved snapshots show correct colours.

    Args:
        blur_paths: File names of blurred images, relative to ``blur_dir``.
        sharp_paths: File names of the matching sharp images, relative to
            ``sharp_dir``, paired with ``blur_paths`` by index. When ``None``
            only blurred images are returned.
        transforms: Optional callable applied to every loaded image
            (an ``H x W x 3`` uint8 RGB array).
        blur_dir: Folder holding the blurred images.
        sharp_dir: Folder holding the sharp images.
    """

    def __init__(self, blur_paths, sharp_paths=None, transforms=None,
                 blur_dir="../train/blur", sharp_dir="../train/original"):
        self.X = blur_paths
        self.y = sharp_paths
        self.transforms = transforms
        self.blur_dir = blur_dir
        self.sharp_dir = sharp_dir

    def __len__(self):
        return len(self.X)

    def _load(self, path):
        """Read ``path`` as an RGB uint8 array and apply ``self.transforms`` if set."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError(f"could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transforms:
            image = self.transforms(image)
        return image

    def __getitem__(self, i):
        blur_image = self._load(f"{self.blur_dir}/{self.X[i]}")
        if self.y is None:
            return blur_image
        sharp_image = self._load(f"{self.sharp_dir}/{self.y[i]}")
        return (blur_image, sharp_image)


class DeblurDatasetVal(DeblurDataset):
    """:class:`DeblurDataset` reading from the validation folders by default."""

    def __init__(self, blur_paths, sharp_paths=None, transforms=None,
                 blur_dir="../val/blur", sharp_dir="../val/original"):
        super().__init__(blur_paths, sharp_paths, transforms, blur_dir, sharp_dir)


# Datasets for both splits and their loaders; only training batches are shuffled.
train_data, val_data = (
    DeblurDataset(x_train, y_train, transform),
    DeblurDatasetVal(x_val, y_val, transform),
)
trainloader, valloader = (
    DataLoader(train_data, batch_size=batch_size, shuffle=True),
    DataLoader(val_data, batch_size=batch_size, shuffle=False),
)

class DeblurCNN(nn.Module):
    """Three-layer fully convolutional network (3 -> 64 -> 32 -> 3 channels).

    The spatial size is preserved end to end. conv1 (9x9, padding 2) shrinks
    each side by 4 pixels, conv2 (1x1, padding 2) grows it back by 4, and
    conv3 (5x5, padding 2) leaves it unchanged.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=2)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)

    def forward(self, x):
        """Map a 3-channel image batch to a 3-channel output of the same size."""
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        return self.conv3(hidden)
# Network instance on the selected device.
model = DeblurCNN().to(device)
print(model)

# Loss terms summed during training: MS-SSIM (pytorch_msssim.msssim),
# mean-squared error and mean-absolute error.
criterion = pytorch_msssim.msssim
criterion2 = nn.MSELoss()
criterion3 = nn.L1Loss()

# Adam at lr=1e-3. ReduceLROnPlateau halves the learning rate when the loss
# it is stepped with stops decreasing (patience of 6 epochs).
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=6, factor=0.5, verbose=True,
)

def _combined_loss(outputs, target):
    """Return the objective (1 - MS-SSIM) + MSE + L1 for one batch.

    Uses the module-level ``criterion`` (MS-SSIM), ``criterion2`` (MSE) and
    ``criterion3`` (L1), so training and validation use the same quantity.
    """
    ssim_term = 1 - criterion(outputs, target, normalize="relu")
    return ssim_term + criterion2(outputs, target) + criterion3(outputs, target)


def fit(model, dataloader, epoch):
    """Train ``model`` for one epoch.

    Args:
        model: Network to train; updated through the module-level ``optimizer``.
        dataloader: Yields ``(blurred, sharp)`` batches.
        epoch: Current epoch index (not used here; kept for symmetry with
            ``validate``).

    Returns:
        Mean per-batch training loss (0.0 for an empty loader).
    """
    model.train()
    running_loss = 0.0
    for blur_image, sharp_image in tqdm(dataloader, total=len(dataloader)):
        blur_image = blur_image.to(device)
        sharp_image = sharp_image.to(device)
        optimizer.zero_grad()
        outputs = model(blur_image)
        loss = _combined_loss(outputs, sharp_image)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # running_loss accumulates per-batch *means*, so average over batches,
    # not over samples (dividing by the dataset size under-reports the loss).
    train_loss = running_loss / max(len(dataloader), 1)
    print(f"Train Loss: {train_loss:.5f}")
    return train_loss


def validate(model, dataloader, epoch):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    On the last batch it writes snapshots under ``../outputs/saved_images_new6``:
    the sharp and blurred inputs (epoch 0 only) and the model output (every
    epoch).

    Args:
        model: Network to evaluate.
        dataloader: Yields ``(blurred, sharp)`` batches.
        epoch: Current epoch index, used in the snapshot file names.

    Returns:
        Mean per-batch validation loss (0.0 for an empty loader).
    """
    model.eval()
    running_loss = 0.0
    # Index of the final batch (including a trailing partial batch).
    last_batch = len(dataloader) - 1
    with torch.no_grad():
        for i, (blur_image, sharp_image) in tqdm(enumerate(dataloader), total=len(dataloader)):
            blur_image = blur_image.to(device)
            sharp_image = sharp_image.to(device)
            outputs = model(blur_image)
            running_loss += _combined_loss(outputs, sharp_image).item()

            if i == last_batch:
                if epoch == 0:
                    save_decoded_image(sharp_image.cpu(), name=f"../outputs/saved_images_new6/sharp{epoch}.jpg")
                    save_decoded_image(blur_image.cpu(), name=f"../outputs/saved_images_new6/blur{epoch}.jpg")
                save_decoded_image(outputs.cpu(), name=f"../outputs/saved_images_new6/val_deblurred{epoch}.jpg")

    val_loss = running_loss / max(len(dataloader), 1)
    print(f"Val Loss: {val_loss:.5f}")
    return val_loss


# Main loop: one fit/validate pass per epoch. The LR scheduler is stepped
# with the validation loss.
train_loss = []
val_loss = []
start = time.time()
for epoch in range(args['epochs']):
    print(f"Epoch {epoch+1} of {args['epochs']}")
    train_epoch_loss = fit(model, trainloader, epoch)
    val_epoch_loss = validate(model, valloader, epoch)
    train_loss.append(train_epoch_loss)
    val_loss.append(val_epoch_loss)
    scheduler.step(val_epoch_loss)
end = time.time()
print(f"Took {((end-start)/60):.3f} minutes to train")

# Persist the weights before plotting: plt.show() can block while the figure
# window is open, and an interruption there would otherwise lose the model.
print('Saving model...')
torch.save(model.state_dict(), '../outputs/model1epoch11.pth')

# Loss curves for both splits.
plt.figure(figsize=(10, 7))
plt.plot(train_loss, color='orange', label='train loss')
plt.plot(val_loss, color='red', label='validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig('../outputs/loss11.png')
plt.show()

Inference Loop

In [ ]:
import numpy as np
import os
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import argparse
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import save_image
from PIL import Image

# Use the first CUDA device when one is available, otherwise the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda:0'

def save_decoded_image(img, name, height=512, width=512):
    """Save an image tensor batch to ``name`` via ``torchvision.utils.save_image``.

    Args:
        img: Image tensor. A 4-D ``(N, C, H, W)`` batch is saved with its own
            spatial size, so test images at their native resolution work.
            Any other layout is reshaped to ``(img.size(0), 3, height, width)``.
        name: Output file path.
        height: Image height used when ``img`` has to be reshaped.
        width: Image width used when ``img`` has to be reshaped.
    """
    if img.dim() != 4:
        # Only non-batched layouts need restoring; forcing a 4-D batch to
        # 512x512 fails for any other input resolution.
        img = img.view(img.size(0), 3, height, width)
    save_image(img, name)

class DeblurCNN(nn.Module):
    """Three-layer fully convolutional network (3 -> 64 -> 32 -> 3 channels).

    Must match the training definition exactly so the saved state_dict loads.
    Spatial size is preserved: conv1 (9x9, padding 2) shrinks each side by 4,
    conv2 (1x1, padding 2) grows it back by 4, and conv3 (5x5, padding 2)
    keeps it unchanged.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=2)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)

    def forward(self, x):
        """Map a 3-channel image batch to a 3-channel output of the same size."""
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        return self.conv3(hidden)

# Load the trained weights. map_location lets a checkpoint saved on a GPU be
# loaded on a CPU-only machine.
model = DeblurCNN().to(device)
model.load_state_dict(torch.load('../outputs/model1epoch11.pth', map_location=device))
model.eval()

# Test images are fed at their native resolution (no resize): only ToTensor.
transform = transforms.Compose([
    transforms.ToTensor(),
])

blurs = sorted(os.listdir('../test/blur'))

# Create the output folder up front so saving does not fail on a missing dir.
output_dir = '../test/original7'
os.makedirs(output_dir, exist_ok=True)

blurred_images = []
with torch.no_grad():  # inference only: no autograd graph is needed
    for name in blurs:
        blurred_images.append(name)
        print(name)
        # convert('RGB') guards against grayscale/RGBA files, since the network
        # takes exactly 3 input channels; the with-block closes the file.
        with Image.open(f"../test/blur/{name}") as pil_img:
            img = transform(pil_img.convert('RGB'))
        batch = torch.unsqueeze(img, 0).to(device)
        output = model(batch)
        save_decoded_image(output.cpu(), f"{output_dir}/{name}")


# img = Image.open("../train/blur/00a35.jpg")
# # img = Image.open("../outputs/testingpurposes/00a35inf2.jpg")


# img_t = transform(img)

# batch_t = torch.unsqueeze(img_t, 0).to(device)

# model.eval()

# out = model(batch_t)

# print(out.shape)

# save_decoded_image(out.cpu().data, "../outputs/testingpurposes/00a35inf3.jpg")

Comments

You must log in before you can post a comment.

Execute