U-Net Architecture: Revolutionizing Computer Vision Through Innovative Image Segmentation

Lukman Aliyu
5 min read · Aug 28, 2023


Introduction

In the dynamic world of computer vision, where innovation is key, few architectures have made as profound an impact as U-Net. Published in 2015 by Olaf Ronneberger, Philipp Fischer, and Thomas Brox, the paper titled “U-Net: Convolutional Networks for Biomedical Image Segmentation” introduced a revolutionary approach to image segmentation. Its influence extends far beyond its biomedical origins, as U-Net’s innovative design and remarkable performance have shaped the landscape of computer vision. In this blog post, we’ll explore the motivations, architecture, and the far-reaching impact of the U-Net.

Motivation and Problem Statement

The crux of U-Net’s inception lay in biomedical image segmentation. The challenge was clear: accurately assign each pixel in an image to a specific class or category. Traditional methods struggled with capturing fine details and object boundaries, relying heavily on handcrafted features. U-Net aimed to change this paradigm.

Architecture Overview

Fig. 1: The U-Net architecture

U-Net’s architecture resembles the letter “U,” a design that would become iconic in the field. It comprises two primary components: the contracting path (left side) and the expansive path (right side).

  • Contracting Path: This part adopts a typical convolutional neural network (CNN) structure, with convolutional and pooling layers. These layers progressively reduce spatial dimensions while enhancing feature channels, enabling the network to learn hierarchical features.
  • Expansive Path: The expansive path upscales learned features using transpose convolutions (or deconvolutions). What sets U-Net apart is the fusion of these upsampled features with those from the contracting path. This mechanism, called skip connections, merges high-level contextual information with detailed spatial data, mitigating information loss.
  • Skip Connections: Skip connections emerged as a pivotal innovation in U-Net. They combat information loss during downsampling and upsampling, significantly improving accuracy, especially for capturing fine details and object boundaries.
  • Loss Function: U-Net uses a pixel-wise cross-entropy loss function. It quantifies the disparity between the predicted segmentation and the ground truth on a per-pixel basis, facilitating accurate pixel-level learning (see the short sketch after this list).
  • Data Augmentation: Data augmentation plays a vital role in U-Net’s robustness. Techniques like random cropping, elastic deformations, and rotations artificially enhance training data diversity (an illustrative pipeline follows below).
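
As a minimal sketch of the pixel-wise loss (the shapes here are illustrative assumptions, not values from the paper), PyTorch’s nn.CrossEntropyLoss already averages the cross-entropy over every pixel when given logits of shape (N, C, H, W) and integer labels of shape (N, H, W):

# Minimal sketch: pixel-wise cross-entropy for segmentation
import torch
from torch import nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(2, 3, 64, 64)         # (N, C, H, W): 2 images, 3 classes, 64×64 pixels
labels = torch.randint(0, 3, (2, 64, 64))  # (N, H, W): one class index per pixel
loss = criterion(logits, labels)           # averaged over all pixels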
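
And an illustrative augmentation pipeline (the parameter values are arbitrary choices, not those of the paper); note that for segmentation the same random transform must be applied to the image and its mask, which in practice means using functional transforms or a library such as albumentations:

# Illustrative augmentation pipeline; parameter values are arbitrary
import torchvision.transforms as T

augment = T.Compose([
    T.RandomCrop(512),               # random cropping
    T.RandomRotation(degrees=15),    # small random rotations
    T.ElasticTransform(alpha=50.0),  # elastic deformation (torchvision >= 0.13)
])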

Revolutionizing Impact

The U-Net architecture has left an indelible mark on computer vision, thanks to several key factors:
  • Efficient Semantic Segmentation: U-Net introduced an efficient architecture for semantic segmentation. Its ability to capture both global context and fine details via skip connections proved invaluable for tasks requiring precise object boundaries and details.
  • Skip Connections and Multi-scale Information: The integration of multi-scale information via skip connections elevated U-Net’s accuracy, surpassing traditional CNNs.
  • State-of-the-Art Performance: U-Net quickly gained acclaim for its outstanding performance in biomedical image segmentation. It consistently outperformed existing methods, setting new accuracy benchmarks.
  • Broad Applicability: While initially conceived for biomedical imaging, U-Net’s principles have transcended boundaries. It has been adapted for diverse computer vision tasks, including natural scene segmentation, satellite image analysis, self-driving cars, and beyond.

PyTorch Implementation of U-Net

Now, let’s dive into a PyTorch implementation of the U-Net architecture, complete with explanations and code snippets. These were adapted from here.

# importing relevant libraries
import torch
import torchvision.transforms.functional
from torch import nn

Padding is used in this implementation to avoid cropping the final image. This is the key difference from the original U-Net paper, which used unpadded convolutions and therefore had to crop the contracting-path feature maps before concatenation.
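
To see the difference concretely (a tiny sketch using the imports above): a 3×3 convolution with padding=1 preserves spatial size, while the unpadded version used in the original paper shrinks each side by 2.

# Padded vs. unpadded 3×3 convolution on a 64×64 feature map
conv_padded = nn.Conv2d(1, 1, kernel_size=3, padding=1)
conv_valid = nn.Conv2d(1, 1, kernel_size=3)  # padding=0, as in the original paper
t = torch.randn(1, 1, 64, 64)
print(conv_padded(t).shape)  # torch.Size([1, 1, 64, 64])
print(conv_valid(t).shape)   # torch.Size([1, 1, 62, 62])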

# Defining the layers
class DoubleConvolution(nn.Module):
    '''Each step in the contracting path and the expansive path has two 3×3 convolutional layers, each followed by a ReLU activation.'''
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.first = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.act1 = nn.ReLU()
        self.second = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.act2 = nn.ReLU()

    def forward(self, x: torch.Tensor):
        # Apply the two convolution layers and activations
        x = self.first(x)
        x = self.act1(x)
        x = self.second(x)
        return self.act2(x)



class Downsample(nn.Module):
    # Each step in the contracting path down-samples the feature map with a 2×2 max pooling layer.
    def __init__(self):
        super().__init__()
        # Max pooling layer
        self.pool = nn.MaxPool2d(2)

    def forward(self, x: torch.Tensor):
        return self.pool(x)



class Upsample(nn.Module):
    # Each step in the expansive path up-samples the feature map with a 2×2 up-convolution.
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Up-convolution
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor):
        return self.up(x)



class CropAndConcat(nn.Module):
    '''At every step in the expansive path, the corresponding feature map from the contracting path is concatenated with the current feature map.'''
    def forward(self, x: torch.Tensor, contracting_x: torch.Tensor):
        # Crop the feature map from the contracting path to the size of the current feature map
        contracting_x = torchvision.transforms.functional.center_crop(contracting_x, [x.shape[2], x.shape[3]])
        # Concatenate the feature maps along the channel dimension
        x = torch.cat([x, contracting_x], dim=1)
        return x

# The U-Net architecture
class UNet(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Double convolution layers for the contracting path. The number of features doubles at each step, starting from 64.
        self.down_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in [(in_channels, 64), (64, 128), (128, 256), (256, 512)]])
        # Down-sampling layers for the contracting path
        self.down_sample = nn.ModuleList([Downsample() for _ in range(4)])
        # The two convolution layers at the lowest resolution (the bottom of the U)
        self.middle_conv = DoubleConvolution(512, 1024)
        # Up-sampling layers for the expansive path. The number of features is halved with up-sampling.
        self.up_sample = nn.ModuleList([Upsample(i, o) for i, o in [(1024, 512), (512, 256), (256, 128), (128, 64)]])
        # Double convolution layers for the expansive path. Their input is the concatenation of the current feature map
        # and the corresponding feature map from the contracting path, so the number of input features is doubled.
        self.up_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in [(1024, 512), (512, 256), (256, 128), (128, 64)]])
        # Crop and concatenate layers for the expansive path
        self.concat = nn.ModuleList([CropAndConcat() for _ in range(4)])
        # Final 1×1 convolution layer to produce the output
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor):
        # Collect the outputs of the contracting path for later concatenation with the expansive path
        pass_through = []
        # Contracting path
        for i in range(len(self.down_conv)):
            # Two 3×3 convolutional layers
            x = self.down_conv[i](x)
            # Collect the output
            pass_through.append(x)
            # Down-sample
            x = self.down_sample[i](x)
        # Two 3×3 convolutional layers at the bottom of the U-Net
        x = self.middle_conv(x)
        # Expansive path
        for i in range(len(self.up_conv)):
            # Up-sample
            x = self.up_sample[i](x)
            # Concatenate the output of the contracting path
            x = self.concat[i](x, pass_through.pop())
            # Two 3×3 convolutional layers
            x = self.up_conv[i](x)
        # Final 1×1 convolution layer
        x = self.final_conv(x)
        return x
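
As a quick sanity check (the shapes below are illustrative), the model can be run on a random tensor. Since every convolution is padded, the output segmentation map has the same spatial resolution as the input, provided the input side length is divisible by 16 (four 2×2 poolings).

# Quick sanity check with illustrative shapes
model = UNet(in_channels=3, out_channels=2)
x = torch.randn(1, 3, 128, 128)  # one RGB image, 128×128 pixels
out = model(x)
print(out.shape)  # torch.Size([1, 2, 128, 128]): padded convolutions preserve resolution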

In conclusion, U-Net’s innovative architecture and exceptional performance have revolutionized computer vision, making it a cornerstone in the field. Its adaptability and versatility have expanded its influence far beyond its original domain, leaving an enduring legacy in the world of computer vision.
