UNetPlusPlus Image Segmentation Code Analysis#
Training Code and Explanation#
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from my_dataset import ImageSegmentationDataset # Custom dataset
from NestedUNet import NestedUNet # Model definition file
# Define hyperparameters
batch_size = 1
learning_rate = 1e-4
num_epochs = 200
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# Calculate new dimensions, dividing original dimensions by 2
new_height = 2048 // 2
new_width = 3072 // 2
# Data preprocessing and data augmentation
transform = transforms.Compose([
    transforms.Resize((new_height, new_width)),  # Resize image to half of original size
    transforms.ToTensor()  # Convert to PyTorch tensor
])
# Load data
train_dataset = ImageSegmentationDataset(image_dir='./dataset/train/images',
                                         mask_dir='./dataset/train/masks',
                                         transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Initialize model, loss function, optimizer
model = NestedUNet(num_classes=2, input_channels=3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Training loop
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)
        # Ensure target tensor shape is [batch_size, height, width]
        masks = torch.squeeze(masks, dim=1)  # Remove channel dimension
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()  # Accumulate batch losses for the epoch average
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / len(train_loader):.4f}')
# Save the trained model
torch.save(model.state_dict(), './model.pth')
1. Data Preprocessing and Loading#
transform = transforms.Compose([
    transforms.Resize((new_height, new_width)),  # Resize image to half of original size
    transforms.ToTensor()  # Convert to PyTorch tensor
])
- Resize: Resizes images and masks to the new dimensions (new_height, new_width), half of the original size (2048, 3072).
- ToTensor: Converts images and masks to PyTorch tensors and normalizes pixel values to the range [0, 1].
train_dataset = ImageSegmentationDataset(image_dir='./dataset/train/images',
                                         mask_dir='./dataset/train/masks',
                                         transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
- ImageSegmentationDataset: Custom dataset class responsible for loading images and their corresponding masks.
- DataLoader: Wraps the dataset into an iterable DataLoader, setting batch size and shuffle.
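As a quick sanity check (a minimal sketch, assuming the directories above exist and contain at least one image/mask pair), you can pull a single batch from train_loader and confirm the tensor shapes before training:

images, masks = next(iter(train_loader))
print(images.shape)  # expected: torch.Size([1, 3, 1024, 1536])
print(masks.shape)   # expected: torch.Size([1, 1, 1024, 1536]) before the squeeze step described below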
2. Model, Loss Function, and Optimizer Initialization#
model = NestedUNet(num_classes=2, input_channels=3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
- NestedUNet: Custom neural network model for image segmentation, with 3 input channels (RGB images) and 2 output classes.
- CrossEntropyLoss: Loss function suitable for multi-class classification tasks, commonly used in image segmentation.
- Adam Optimizer: Used to update network parameters.
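To see the shapes nn.CrossEntropyLoss expects without running the full model, here is a minimal sketch with dummy tensors (the sizes are illustrative only):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(1, 2, 4, 6)          # [batch_size, num_classes, height, width]
target = torch.randint(0, 2, (1, 4, 6))   # [batch_size, height, width], int64 class IDs
print(criterion(logits, target).item())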
3. Training Loop#
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)
        # Ensure target tensor shape is [batch_size, height, width]
        masks = torch.squeeze(masks, dim=1)  # Remove channel dimension
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()  # Accumulate batch losses for the epoch average
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / len(train_loader):.4f}')
- model.train(): Sets the model to training mode, so dropout is active and batch normalization uses per-batch statistics.
- images, masks = images.to(device), masks.to(device): Transfers data to GPU or CPU.
- masks = torch.squeeze(masks, dim=1): This is a key step, explained below.
4. Detailed Explanation of Channel Processing#
In image segmentation tasks:
- Input images are four-dimensional after batching, with a shape of [batch_size, channels, height, width], e.g., [1, 3, 1024, 1536].
- Masks are also four-dimensional but have a channel count of 1, with a shape of [batch_size, 1, height, width], e.g., [1, 1, 1024, 1536].

However, nn.CrossEntropyLoss requires the target mask shape to be [batch_size, height, width], i.e., without the channel dimension. Thus, torch.squeeze is used to remove the channel dimension from the mask:

masks = torch.squeeze(masks, dim=1)

This changes the mask shape from [batch_size, 1, height, width] to [batch_size, height, width], meeting the requirements of the loss function.
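A one-line illustration of that shape change with a dummy mask:

mask = torch.zeros(1, 1, 1024, 1536, dtype=torch.int64)
print(torch.squeeze(mask, dim=1).shape)  # torch.Size([1, 1024, 1536])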
5. Model Output and Loss Calculation#
- outputs = model(images): The model output shape is [batch_size, num_classes, height, width], e.g., [1, 2, 1024, 1536].
- loss = criterion(outputs, masks): Computes the cross-entropy loss between the predicted results and the true masks (a quick accuracy sketch follows below).
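If you also want a quick quality signal during training, one option (not part of the original script, shown only as a sketch) is per-pixel accuracy computed from the same outputs and masks inside the training loop:

# Hypothetical add-on after `loss = criterion(outputs, masks)`:
preds = torch.argmax(outputs, dim=1)                 # [batch_size, height, width]
pixel_acc = (preds == masks).float().mean().item()   # fraction of correctly classified pixels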
6. Model Saving#
torch.save(model.state_dict(), './model.pth')
- Saves the model parameters to the file model.pth, allowing for later loading and inference (see the checkpoint sketch below for a fuller variant).
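If training may need to be resumed later, a common variant (not used in this article's scripts; the filename is only an example) is to save a fuller checkpoint that also contains the optimizer state and the current epoch:

torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, './checkpoint.pth')  # hypothetical path, separate from model.pth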
Inference Code and Explanation#
The main function of this code is to load a pre-trained NestedUNet model, use it to segment images in a specified directory, and save the results to an output directory. The execution flow of the code is as follows:
import os
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from NestedUNet import NestedUNet # Model definition file
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Load model
def load_model(model, path):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model file not found: {path}")
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model
# Perform inference
def segment_images(model, image_dir, output_dir):
    # Calculate new dimensions, dividing original dimensions by 2
    new_height = 2048 // 2
    new_width = 3072 // 2
    # Data preprocessing and data augmentation
    transform = transforms.Compose([
        transforms.Resize((new_height, new_width)),  # Resize image to half of original size
        transforms.ToTensor()  # Convert to PyTorch tensor
    ])
    os.makedirs(output_dir, exist_ok=True)
    for filename in os.listdir(image_dir):
        if filename.endswith(('.png', '.jpg', '.jpeg')):
            filepath = os.path.join(image_dir, filename)
            image = Image.open(filepath).convert('RGB')
            input_tensor = transform(image).unsqueeze(0).to(device)  # Add batch dimension
            with torch.no_grad():
                outputs = model(input_tensor)
                prediction = torch.argmax(outputs, dim=1).squeeze(0)  # Get segmentation result
            # Save segmentation result
            output_filename = filename.split('.')[0] + '_segmentation.png'
            output_path = os.path.join(output_dir, output_filename)
            # Map class values to the range 0-255
            pred_img = prediction.cpu().numpy().astype(np.uint8) * 255
            Image.fromarray(pred_img).save(output_path)
# Main execution code
if __name__ == "__main__":
model = NestedUNet(num_classes=2, input_channels=3).to(device)
model = load_model(model, './model.pth') # Load pre-trained model
# Define input and output directories
input_dirs = [
'./dataset/1-2000',
'./dataset/2001-4000',
'./dataset/4001-6000',
'./dataset/6001-8000',
'./dataset/8001-9663'
]
base_output_dir = './dataset/segmentation_results' # Base output results directory
for input_dir in input_dirs:
output_dir = os.path.join(base_output_dir, os.path.basename(input_dir))
segment_images(model, input_dir, output_dir)
print(f"Segmentation results saved to: {base_output_dir}")
1. Device Selection#
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
- Selects the computing device based on whether a GPU is available (checked via torch.cuda.is_available()). If a GPU is available, the code will use it; otherwise, it will use the CPU.
2. Load Model#
def load_model(model, path):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model file not found: {path}")
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model
The load_model function:
- Checks if the model file exists at the specified path.
- Loads the pre-trained model parameters using torch.load().
- After loading, calls model.eval() to set the model to evaluation mode (disabling dropout and other training-only behavior).
- Returns the model with loaded weights.
3. Perform Image Segmentation Inference#
def segment_images(model, image_dir, output_dir):
    new_height = 2048 // 2
    new_width = 3072 // 2
- Sets the target image size, reducing the original image height and width by half (2048 // 2 and 3072 // 2).
Data Preprocessing#
transform = transforms.Compose([
    transforms.Resize((new_height, new_width)),  # Resize image
    transforms.ToTensor()  # Convert to PyTorch tensor
])
- Images are resized to the new dimensions through the Resize transformation.
- They are then converted to PyTorch tensor format using ToTensor(), making them suitable for model input.
Processing Each Image#
for filename in os.listdir(image_dir):
    if filename.endswith(('.png', '.jpg', '.jpeg')):
        filepath = os.path.join(image_dir, filename)
        image = Image.open(filepath).convert('RGB')
        input_tensor = transform(image).unsqueeze(0).to(device)  # Add batch dimension
- Iterates through all image files in the image_dir directory (supporting .png, .jpg, and .jpeg formats).
- Reads each image and converts it to RGB mode (even grayscale images will be processed as RGB).
- Applies the preprocessing transform to convert each image to a tensor and adds a batch dimension (unsqueeze(0)), making the shape [1, C, H, W] (suitable for model input).
Model Inference#
with torch.no_grad():
    outputs = model(input_tensor)
    prediction = torch.argmax(outputs, dim=1).squeeze(0)  # Get segmentation result
- torch.no_grad() disables gradient calculation, saving memory and speeding up inference.
- model(input_tensor) returns the model's output (per-pixel class scores, i.e., logits).
- torch.argmax(outputs, dim=1): For each pixel, takes the class with the highest score as the predicted class.
- squeeze(0): Removes the batch dimension, resulting in prediction with a shape of [H, W]. A small worked example follows below.
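To make the argmax step concrete, here is a tiny worked example with arbitrary 2-class logits for a 2x2 image (batch size 1):

outputs = torch.tensor([[[[0.1, 2.0],
                          [3.0, 0.2]],
                         [[1.5, 0.3],
                          [0.1, 4.0]]]])           # shape [1, 2, 2, 2]
prediction = torch.argmax(outputs, dim=1).squeeze(0)
print(prediction)  # tensor([[1, 0], [0, 1]]), shape [2, 2]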
Save Segmentation Result#
output_filename = filename.split('.')[0] + '_segmentation.png'
output_path = os.path.join(output_dir, output_filename)
pred_img = prediction.cpu().numpy().astype(np.uint8) * 255
Image.fromarray(pred_img).save(output_path)
- output_filename: Names each output image file, formatted as the original filename plus _segmentation.png.
- prediction.cpu().numpy(): Moves the prediction result from GPU to CPU and converts it to a NumPy array.
- astype(np.uint8) * 255: Maps the predicted classes (0 or 1) to grayscale values (0 or 255), allowing the result to be saved as a black-and-white image.
- Uses Pillow to save pred_img as a PNG file. An optional full-resolution variant is sketched below.
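Note that inference runs at half resolution, so the saved mask is 1536x1024. If the mask is needed at the original 3072x2048 size, one option (not in the original script) is nearest-neighbor resizing before saving, which keeps the hard 0/255 class boundaries:

# Hypothetical variant of the save step:
full_res = Image.fromarray(pred_img).resize((3072, 2048), resample=Image.NEAREST)
full_res.save(output_path)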
4. Main Execution Code#
if __name__ == "__main__":
model = NestedUNet(num_classes=2, input_channels=3).to(device)
model = load_model(model, './model.pth') # Load pre-trained model
input_dirs = [
'./dataset/1-2000',
'./dataset/2001-4000',
'./dataset/4001-6000',
'./dataset/6001-8000',
'./dataset/8001-9663'
]
base_output_dir = './dataset/segmentation_results' # Base output results directory
for input_dir in input_dirs:
output_dir = os.path.join(base_output_dir, os.path.basename(input_dir))
segment_images(model, input_dir, output_dir)
print(f"Segmentation results saved to: {base_output_dir}")
- In the main program, the NestedUNet model is first instantiated and its pre-trained weights are loaded.
- A list of subfolder paths (input_dirs) is defined, each containing images to be segmented.
- For each input folder, a corresponding output folder is created to hold the segmentation results.
- Finally, the path of the results directory is printed.
Data Preprocessing Code and Explanation#
import os
import numpy as np
import torch
from PIL import Image

class ImageSegmentationDataset:
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_files = sorted(os.listdir(image_dir))  # Get list of image files and sort

    def __getitem__(self, idx):
        # Get image filename
        image_file = self.image_files[idx]
        image_path = os.path.join(self.image_dir, image_file)
        # Construct mask filename, assuming mask files end with "_mask"
        mask_file = image_file.replace(".jpg", "_mask.png")
        mask_path = os.path.join(self.mask_dir, mask_file)
        # Load image and mask
        image = Image.open(image_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')  # Grayscale image
        # If there is a transform (data augmentation, etc.), apply it
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)
        mask = torch.tensor(np.array(mask, dtype=np.int64))
        return image, mask

    def __len__(self):
        # Return the number of image files in the dataset
        return len(self.image_files)
__init__ Constructor#
def __init__(self, image_dir, mask_dir, transform=None):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform = transform
    self.image_files = sorted(os.listdir(image_dir))  # Get list of image files and sort
- image_dir: Directory path where images are stored.
- mask_dir: Directory path where mask images are stored. Each image has a corresponding mask image that annotates the target area.
- transform: Optional data preprocessing or augmentation operations (for example, resizing and normalization) passed in as a callable.
- image_files: Retrieves all filenames in image_dir and sorts them to ensure the order of images matches the order of masks.
__getitem__ Method#
def __getitem__(self, idx):
    # Get image filename
    image_file = self.image_files[idx]
    image_path = os.path.join(self.image_dir, image_file)
    # Construct mask filename, assuming mask files end with "_mask"
    mask_file = image_file.replace(".jpg", "_mask.png")
    mask_path = os.path.join(self.mask_dir, mask_file)
    # Load image and mask
    image = Image.open(image_path).convert('RGB')
    mask = Image.open(mask_path).convert('L')  # Grayscale image
- idx: The index passed in, indicating which image and its corresponding mask to load from the dataset.
- image_file: Gets the current image filename based on idx.
- image_path: Constructs the full path for the image based on its filename.
- mask_file: Assumes that the mask image has the same base filename as the original image, with _mask appended before the extension (the original file is .jpg and the mask is .png). This rule can be modified as needed.
- mask_path: Constructs the full path for the mask image based on its filename.
- Load image and mask:
  - Uses Pillow's Image.open() to load the image and ensures it is in RGB format using .convert('RGB').
  - The mask is loaded as a grayscale image using .convert('L').
Apply Preprocessing Operations#
if self.transform:
    image = self.transform(image)
    mask = self.transform(mask)
- If a transform is provided (for example, data augmentation or other preprocessing), it is applied to both the image and the mask. Typically, resizing, normalization, and data augmentation are performed here.
Convert Mask to PyTorch Tensor#
mask = torch.tensor(np.array(mask, dtype=np.int64))
- After the transform, the mask is converted to a NumPy array (if a ToTensor transform was applied, its values are already scaled to [0, 1]).
- The NumPy array is then converted to a PyTorch tensor of type int64. Using int64 is required because the labels in segmentation tasks are integer class IDs (e.g., each pixel corresponds to a class index), which is what nn.CrossEntropyLoss expects. A small check of this value mapping is shown below.
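A small check of that value mapping (a sketch, assuming the mask PNGs store 0 for background and 255 for foreground): ToTensor scales 255 to 1.0, so the int64 conversion yields class IDs 0 and 1.

from PIL import Image
import numpy as np
import torch
from torchvision import transforms

mask = Image.fromarray(np.array([[0, 255], [255, 0]], dtype=np.uint8), mode='L')
t = transforms.ToTensor()(mask)                  # float tensor with values 0.0 and 1.0
ids = torch.tensor(np.array(t, dtype=np.int64))  # int64 class IDs 0 and 1
print(ids)  # tensor([[[0, 1], [1, 0]]])

If the masks were instead stored with values 0 and 1, ToTensor would scale them to roughly 0.004 and the int64 cast would collapse everything to class 0, so the 0/255 convention matters here.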
__len__ Method#
def __len__(self):
    # Return the number of image files in the dataset
    return len(self.image_files)
- This method returns the number of image files in the dataset. The PyTorch dataset class needs to implement this method to know the size of the dataset.
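A minimal usage sketch, assuming the same directory layout and transform as in the training script (images are .jpg files and masks are the matching *_mask.png files):

from torchvision import transforms
from my_dataset import ImageSegmentationDataset

transform = transforms.Compose([
    transforms.Resize((1024, 1536)),
    transforms.ToTensor()
])
dataset = ImageSegmentationDataset(image_dir='./dataset/train/images',
                                   mask_dir='./dataset/train/masks',
                                   transform=transform)
print(len(dataset))
image, mask = dataset[0]
print(image.shape)  # torch.Size([3, 1024, 1536])
print(mask.shape)   # torch.Size([1, 1024, 1536])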
This code implements a deep learning model called Nested U-Net, primarily used for image segmentation tasks. Nested U-Net is an improved structure based on the traditional U-Net, which enhances segmentation accuracy by adding nested skip connections. Below, I will explain each part of the code in detail, especially the role of each module.
Nested U-Net#
VGGBlock#
class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        return out
- VGGBlock is the core convolutional block of the model. Each block consists of:
  - Convolutional Layers: conv1 and conv2, both using a 3x3 kernel with padding=1 so the output spatial size matches the input.
  - Batch Normalization Layers: bn1 and bn2, used to accelerate training and stabilize the model.
  - ReLU Activation: Adds non-linearity to the model.

This block is called repeatedly and forms the basis of U-Net and Nested U-Net. A quick shape check follows below.
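A quick shape check (illustrative sizes): because both convolutions use a 3x3 kernel with padding=1, only the channel count changes, not the spatial size.

import torch
block = VGGBlock(in_channels=3, middle_channels=32, out_channels=32)
x = torch.randn(1, 3, 64, 96)
print(block(x).shape)  # torch.Size([1, 32, 64, 96])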
NestedUNet#
class NestedUNet(nn.Module):
    def __init__(self, num_classes=2, input_channels=2, deep_supervision=False, **kwargs):
        super().__init__()
        nb_filter = [32, 64, 128, 256, 512]
        self.deep_supervision = deep_supervision
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Define convolution modules for each layer
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
        # Define nested convolution modules (i.e., skip connections)
        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])
        # Final output layer, supporting deep supervision
        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
Main Modules:#
- Convolutional Blocks: Each node is composed of a VGGBlock. The number of output channels grows down the encoder (32, 64, 128, 256, 512), and the nested nodes then fuse features through skip connections.
- Skip Connections: This design is the key to Nested U-Net: the output of each node is not only passed deeper but also concatenated with outputs from other nodes at the same depth. This helps retain more detailed information and improves segmentation accuracy.
- Upsampling: nn.Upsample (bilinear) doubles the spatial size so that features from the level below can be concatenated with shallower ones before the next convolution.
- Deep Supervision: Producing outputs at multiple stages allows the model to be supervised at different depths, which can improve performance. The channel arithmetic behind the concatenations is checked in the sketch below.
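The input-channel arithmetic of the nested blocks follows directly from the concatenations: node conv{i}_{j} receives j feature maps of nb_filter[i] channels from its own row plus one upsampled feature map of nb_filter[i+1] channels from the row below. A tiny check against the definitions above:

nb_filter = [32, 64, 128, 256, 512]
def in_channels(i, j):
    # input channels of conv{i}_{j} for j >= 1
    return j * nb_filter[i] + nb_filter[i + 1]

print(in_channels(0, 1))  # 96  -> conv0_1 = nb_filter[0] + nb_filter[1]
print(in_channels(0, 2))  # 128 -> conv0_2 = nb_filter[0]*2 + nb_filter[1]
print(in_channels(1, 3))  # 320 -> conv1_3 = nb_filter[1]*3 + nb_filter[2]
print(in_channels(0, 4))  # 192 -> conv0_4 = nb_filter[0]*4 + nb_filter[1]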
Forward Method#
def forward(self, input):
    # Various convolution operations
    x0_0 = self.conv0_0(input)
    x1_0 = self.conv1_0(self.pool(x0_0))
    x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
    x2_0 = self.conv2_0(self.pool(x1_0))
    x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
    x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
    # Continue nested connections and convolution operations until the last layer
    x3_0 = self.conv3_0(self.pool(x2_0))
    x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
    x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
    x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
    x4_0 = self.conv4_0(self.pool(x3_0))
    x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
    x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
    x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
    x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
    if self.deep_supervision:
        output1 = self.final1(x0_1)
        output2 = self.final2(x0_2)
        output3 = self.final3(x0_3)
        output4 = self.final4(x0_4)
        return [output1, output2, output3, output4]
    else:
        output = self.final(x0_4)
        return output
- Convolution and Pooling Operations: Through
self.pool
, downsampling (pooling) is performed, and throughself.up
, upsampling (deconvolution) is done, concatenating outputs from different layers. - Deep Supervision Outputs: If deep supervision is enabled, results are produced at multiple intermediate layers; otherwise, only the final output is produced.