Tree Segmentation
[Getting Started Notebook] Trees Segmentation
A Getting Started notebook for the Tree Segmentation puzzle of BlitzX.
Starter Code for Tree Segmentation
What we are going to learn
- Getting started with image segmentation using PyTorch.
- Using models provided by segmentation_models.pytorch for image segmentation.
- Training & testing a U-Net model with PyTorch.
Note: Create a copy of the notebook and use the copy for submission. Go to File > Save a Copy in Drive to create a new copy.
Setting up Environment
Downloading Dataset
First, we need to install AIcrowd's Python library (aicrowd-cli), which lets us download the dataset just by entering our API key.
!pip install aicrowd-cli
%load_ext aicrowd.magic
%aicrowd login
# Downloading the Dataset
!rm -rf data
!mkdir data
%aicrowd ds dl -c tree-segmentation -o data
!unzip data/train.zip -d data/train > /dev/null
!unzip data/test.zip -d data/test > /dev/null
Downloading & Importing Libraries
Here we are going to use segmentation_models.pytorch, a popular library that provides many different segmentation models for PyTorch, from basic U-Nets to DeepLabV3!
Along with that, we will also use the pytorch-argus library to help train the model.
!pip install git+https://github.com/qubvel/segmentation_models.pytorch pytorch-argus
# Pytorch
import torch
from torch import nn
import segmentation_models_pytorch as smp
import argus
from torch.utils.data import Dataset, DataLoader
# Reading Dataset, vis and miscellaneous
from PIL import Image
import matplotlib.pyplot as plt
import os
import numpy as np
from tqdm.notebook import tqdm
import cv2
from natsort import natsorted
Training phase ⚙️
Creating the Dataloader
Here, we simply create a Dataset class that PyTorch uses to load the images and masks and feed them into the model.
class TreeSegmentationDataset(Dataset):
    def __init__(self, img_directory=None, label_directory=None, train=True):
        self.img_directory = img_directory
        self.label_directory = label_directory
        self.train = train
        self.img_list = []
        self.label_list = []
        # If the image directory is valid, list its files in natural order
        if img_directory is not None:
            self.img_list = natsorted(os.listdir(img_directory))
        # Masks are only needed (and only available) for the training split
        if train and label_directory is not None:
            self.label_list = natsorted(os.listdir(label_directory))

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        # Reading the image
        img = Image.open(os.path.join(self.img_directory, self.img_list[idx]))
        img = np.array(img, dtype=np.float32)
        # Change image channel ordering from (H, W, C) to (C, H, W)
        img = np.moveaxis(img, -1, 0)
        if self.train:
            # Reading the mask image
            mask = Image.open(os.path.join(self.label_directory, self.label_list[idx]))
            mask = np.array(mask, dtype=np.float32)
            return img, mask
        # If reading the test dataset, only return the image
        return img
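The dataset relies on natsorted (from the natsort package) so that image i lines up with mask i and test predictions keep the file order. A plain sorted() would put 10.png before 2.png. Below is a minimal pure-Python sketch of a natural sort key. The helper name natural_key is hypothetical, and natsort itself handles many more cases (signs, decimals, locales).

```python
import re

def natural_key(name):
    # Split "img12.png" into ["img", 12, ".png"] so numeric chunks compare as numbers;
    # re.split with a capturing group alternates text and digits, so types line up by position
    return [int(part) if part.isdigit() else part.lower()
            for part in re.split(r"(\d+)", name)]

files = ["10.png", "2.png", "1.png"]
print(sorted(files))                   # ['1.png', '10.png', '2.png']
print(sorted(files, key=natural_key))  # ['1.png', '2.png', '10.png']
```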
# Creating the training dataset
train_dataset = TreeSegmentationDataset(img_directory="data/train/image", label_directory="data/train/segmentation")
train_loader = DataLoader(train_dataset, batch_size=4, num_workers=1, shuffle=False, drop_last=True)
# Reading one batch of images and their corresponding segmentations
image_batch, segmentation_batch = next(iter(train_loader))
image_batch.shape, segmentation_batch.shape
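np.moveaxis(img, -1, 0) turns an (H, W, 3) image into (3, H, W), the channel-first layout PyTorch convolutions expect. The DataLoader then stacks four such arrays into a batch. The NumPy-only sketch below walks through that shape bookkeeping using random arrays in place of the real tiles. The 256×256 size is an assumption made purely for illustration.

```python
import numpy as np

H, W, batch_size = 256, 256, 4  # assumed sizes for illustration only

# Fake RGB images and masks shaped as np.array(PIL image) would return them
images = [np.random.rand(H, W, 3).astype(np.float32) for _ in range(batch_size)]
masks = [np.random.randint(0, 2, (H, W)).astype(np.float32) * 255 for _ in range(batch_size)]

# Channel-last -> channel-first, as done in __getitem__
chw = [np.moveaxis(img, -1, 0) for img in images]

# Batching stacks the samples along a new leading axis
image_batch = np.stack(chw)
mask_batch = np.stack(masks)
print(image_batch.shape, mask_batch.shape)  # (4, 3, 256, 256) (4, 256, 256)
```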
Visualizing Dataset
plt.rcParams["figure.figsize"] = (30,5)
# Going through each image and segmentation
for image, segmentation in zip(image_batch, segmentation_batch):
    # Change the channel ordering back to (H, W, C) and scale to [0, 1]
    image = np.moveaxis(image.numpy()/255, 0, -1)
    # Showing the image (left) and the image with the mask overlaid (right)
    plt.figure()
    plt.subplot(1,2,1)
    plt.imshow(image, 'gray', interpolation='none')
    plt.subplot(1,2,2)
    plt.imshow(image, 'gray', interpolation='none')
    plt.imshow(segmentation, 'jet', interpolation='none', alpha=0.7)
    plt.show()
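Creating the Model
Here we set up the model architecture, optimizer and loss. The U-Net predicts one channel per pixel, shaped (B, 1, H, W), while the masks from the dataset are shaped (B, H, W), so the loss adds the missing channel axis to the target before comparing the two.
class MaskMSELoss(nn.MSELoss):
    # MSE loss that reshapes masks from (B, H, W) to (B, 1, H, W)
    # so each prediction is compared only with its own mask
    def forward(self, prediction, target):
        return super().forward(prediction, target.unsqueeze(1))

class TreeSegmentationModel(argus.Model):
    nn_module = smp.Unet
    optimizer = torch.optim.Adam
    loss = MaskMSELoss

model = TreeSegmentationModel({
    'device': 'cuda',
    'nn_module': {
        'encoder_name': 'resnet18',
        'classes': 1,
        'in_channels': 3
    }
})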
model
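With classes=1, smp.Unet returns predictions of shape (B, 1, H, W), but the dataset yields masks of shape (B, H, W). Elementwise losses such as nn.MSELoss do not fail on mismatched shapes. They broadcast them and only emit a warning, so it is worth checking what shape the comparison actually takes. NumPy follows the same broadcasting rules, so it can illustrate the effect. The sketch below assumes a batch of 4 and 256×256 tiles, and np.broadcast_shapes needs NumPy 1.20 or newer.

```python
import numpy as np

B, H, W = 4, 256, 256  # assumed batch and tile size

pred_shape = (B, 1, H, W)  # Unet output with classes=1
mask_shape = (B, H, W)     # mask batch without a channel axis

# Right-aligned broadcasting pairs (B, 1, H, W) with (1, B, H, W)
print(np.broadcast_shapes(pred_shape, mask_shape))    # (4, 4, 256, 256)

# With a channel axis on the mask, the shapes match one-to-one
print(np.broadcast_shapes(pred_shape, (B, 1, H, W)))  # (4, 1, 256, 256)
```

The first result means every prediction in the batch gets compared against every mask. Adding the channel axis to the target (for example target.unsqueeze(1) in PyTorch) restores the intended pairing.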
Training the Model
model.fit(train_loader,
num_epochs=1,
metrics=['accuracy'],
metrics_on_train=True)
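Alongside the built-in 'accuracy' metric, you can sanity-check a single-channel prediction by binarizing it and comparing pixels with the mask directly. The following is a minimal NumPy sketch of pixel accuracy. It assumes masks are stored as 0/255 and uses the same 128 threshold the notebook applies later when saving predictions. The helper pixel_accuracy is hypothetical and not part of argus.

```python
import numpy as np

def pixel_accuracy(prediction, mask, threshold=128):
    # Hypothetical helper: binarize both arrays, then take the fraction of matching pixels
    pred_bin = prediction > threshold
    mask_bin = mask > threshold
    return float((pred_bin == mask_bin).mean())

pred = np.array([[200.0, 10.0], [130.0, 90.0]], dtype=np.float32)
mask = np.array([[255.0, 0.0], [0.0, 0.0]], dtype=np.float32)
print(pixel_accuracy(pred, mask))  # 0.75
```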
Submitting Results 📄
Okay, this is the last section 😌. Let's get our test results from the model real quick and submit our predictions directly using the AIcrowd CLI.
Loading the Test Dataset
test_dataset = TreeSegmentationDataset(img_directory="data/test/image", train=False)
test_loader = DataLoader(test_dataset, batch_size=4, num_workers=1, shuffle=False, drop_last=False)
Making the Predictions
predictions = []
# Going through each test batch and collecting the predictions
for images in tqdm(test_loader):
    prediction = model.predict(images)
    predictions.extend(prediction.cpu().numpy())
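model.predict returns one (B, 1, H, W) tensor per batch. Calling predictions.extend(...) on its NumPy version iterates over the batch axis, so the list ends up holding one (1, H, W) array per test image. That is why the later code indexes [0] to drop the channel axis. The sketch below uses dummy zero arrays and assumes 10 test images in batches of 4.

```python
import numpy as np

H, W = 256, 256
batch_sizes = [4, 4, 2]  # e.g. 10 test images, batch_size=4, drop_last=False

predictions = []
for b in batch_sizes:
    prediction = np.zeros((b, 1, H, W), dtype=np.float32)  # stand-in for model output
    predictions.extend(prediction)  # appends one (1, H, W) array per image

print(len(predictions), predictions[0].shape, predictions[0][0].shape)
# 10 (1, 256, 256) (256, 256)
```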
# Change the channel ordering of the last test image so it matches predictions[-1]
image = np.moveaxis(images[-1].numpy()/255, 0, -1)
# Showing the image (left) and the image with the predicted mask overlaid (right)
plt.figure()
plt.subplot(1,2,1)
plt.imshow(image, 'gray', interpolation='none')
plt.subplot(1,2,2)
plt.imshow(image, 'gray', interpolation='none')
plt.imshow(predictions[-1][0], 'jet', interpolation='none', alpha=0.7)
plt.show()
!rm -rf segmentation
!mkdir segmentation
for n, img in tqdm(enumerate(predictions)):
    # Drop the channel axis: (1, H, W) -> (H, W)
    img = img[0]
    # Making sure the pixels are only 0 and 255 in the image
    _, img = cv2.threshold(img, 128, 255, cv2.THRESH_BINARY)
    img = Image.fromarray(img.astype(np.uint8))
    img.save(os.path.join("segmentation", f"{n}.png"))
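cv2.threshold with THRESH_BINARY sets a pixel to maxval (255) when it is strictly greater than the threshold (128) and sets it to 0 otherwise. That rule is what keeps the saved masks binary. The NumPy sketch below reproduces the same rule and could stand in if OpenCV is unavailable. The helper name binarize is hypothetical.

```python
import numpy as np

def binarize(pred, thresh=128, maxval=255):
    # Same rule as cv2.THRESH_BINARY: maxval where pred > thresh, else 0
    return np.where(pred > thresh, maxval, 0).astype(np.uint8)

pred = np.array([[0.0, 128.0], [128.5, 300.0]], dtype=np.float32)
print(binarize(pred))
# [[  0   0]
#  [255 255]]
```

Note that a value of exactly 128 maps to 0, because the comparison is strict.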
Note: Please make sure there is a folder named segmentation in your working directory before submitting the predictions.
Uploading the Results
!aicrowd notebook submit -c tree-segmentation -a segmentation --no-verify
Don't be shy about asking questions on any errors you are getting, or on any part of this notebook, in the discussion forum or on the AIcrowd Discord server. AIcrew will be happy to help you :)
Also, want to give us your valuable feedback for the next Blitz, or want to work with us on creating Blitz challenges? Let us know!