kanes


Image Segmentation Code Analysis

UNetPlusPlus Image Segmentation Code Analysis

Training Code and Explanation

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from my_dataset import ImageSegmentationDataset  # Custom dataset
from NestedUNet import NestedUNet  # Model definition file

# Define hyperparameters
batch_size = 1
learning_rate = 1e-4
num_epochs = 200

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Calculate new dimensions, dividing original dimensions by 2
new_height = 2048 // 2
new_width = 3072 // 2

# Data preprocessing (resizing only; no augmentation is applied)
transform = transforms.Compose([
    transforms.Resize((new_height, new_width)),  # Resize image to half of original size
    transforms.ToTensor()  # Convert to PyTorch tensor
])

# Load data
train_dataset = ImageSegmentationDataset(image_dir='./dataset/train/images',
                                         mask_dir='./dataset/train/masks',
                                         transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize model, loss function, optimizer
model = NestedUNet(num_classes=2, input_channels=3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        # Ensure target tensor shape is [batch_size, height, width]
        masks = torch.squeeze(masks, dim=1)  # Remove channel dimension

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
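        optimizer.step()

        epoch_loss += loss.item()  # Accumulate batch losses

    # Report the mean loss over all batches in this epoch
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / len(train_loader):.4f}')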

# Save the trained model
torch.save(model.state_dict(), './model.pth')

1. Data Preprocessing and Loading

transform = transforms.Compose([
    transforms.Resize((new_height, new_width)),  # Resize image to half of original size
    transforms.ToTensor()  # Convert to PyTorch tensor
])
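  • Resize: Resizes images and masks to (new_height, new_width) = (1024, 1536), half of the original size (2048, 3072). The default interpolation is bilinear. That works well for images but blends label values along mask boundaries, so nearest-neighbor interpolation is the usual choice for masks.
  • ToTensor: Converts images and masks to PyTorch tensors of shape [channels, height, width]. For 8-bit images, it also scales pixel values from [0, 255] to [0, 1].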
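The same Resize is applied to the masks, so it is worth checking what its default bilinear interpolation does to label values. The sketch below uses plain Pillow and NumPy on a toy 4x4 mask. It assumes masks are stored as 0 (background) and 255 (foreground).

```python
import numpy as np
from PIL import Image

# Toy 4x4 mask: column 0 is background (0), columns 1-3 are foreground (255).
# The boundary deliberately does not line up with the 2x downscaling grid.
mask = np.zeros((4, 4), dtype=np.uint8)
mask[:, 1:] = 255
mask_img = Image.fromarray(mask)  # 2D uint8 array -> grayscale ('L') image

# Halve the size, as the training transform does (PIL sizes are (width, height))
bilinear = np.array(mask_img.resize((2, 2), Image.BILINEAR))
nearest = np.array(mask_img.resize((2, 2), Image.NEAREST))

print(np.unique(bilinear))  # contains blended values that are neither 0 nor 255
print(np.unique(nearest))   # only values that already exist in the original mask
```

Nearest-neighbor interpolation only copies existing pixel values. That is why it is the usual choice for label maps.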
train_dataset = ImageSegmentationDataset(image_dir='./dataset/train/images',
                                         mask_dir='./dataset/train/masks',
                                         transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
  • ImageSegmentationDataset: Custom dataset class responsible for loading images and their corresponding masks.
  • DataLoader: Wraps the dataset in an iterable that yields batches of batch_size samples (here 1). Because shuffle=True, the sample order is reshuffled every epoch.

2. Model, Loss Function, and Optimizer Initialization

model = NestedUNet(num_classes=2, input_channels=3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
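  • NestedUNet: The custom segmentation network, built with 3 input channels (RGB images) and 2 output classes.
  • CrossEntropyLoss: The standard multi-class classification loss, applied here per pixel. It takes raw logits of shape [batch_size, num_classes, height, width] and integer class-index targets of shape [batch_size, height, width], and averages the loss over all pixels.
  • Adam Optimizer: Updates the network parameters with a learning rate of 1e-4.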
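The following sketch makes the shape contract of CrossEntropyLoss concrete. It feeds the loss random logits and class IDs at a toy resolution (the sizes are arbitrary). It then reproduces the same value by hand as the per-pixel negative log-softmax, averaged over every pixel.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()

logits = torch.randn(2, 2, 8, 12)          # model output: [batch, num_classes, H, W], raw scores
target = torch.randint(0, 2, (2, 8, 12))   # mask: [batch, H, W], int64 class IDs

loss = criterion(logits, target)

# Same value by hand: log-softmax over the class axis, pick each pixel's true class, average
log_probs = torch.log_softmax(logits, dim=1)
manual = -log_probs.gather(1, target.unsqueeze(1)).mean()
print(loss.item(), manual.item())
```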

3. Training Loop

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)
        
        # Ensure target tensor shape is [batch_size, height, width]
        masks = torch.squeeze(masks, dim=1)  # Remove channel dimension
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
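        optimizer.step()

        epoch_loss += loss.item()  # Accumulate batch losses

    # Report the mean loss over all batches in this epoch
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / len(train_loader):.4f}')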
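  • model.train(): Puts the model in training mode. For this model, the main effect is on batch normalization, which normalizes with the current batch's statistics and updates its running averages. Dropout layers would also be enabled, although NestedUNet as defined here has none.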
  • images, masks = images.to(device), masks.to(device): Transfers data to GPU or CPU.
  • masks = torch.squeeze(masks, dim=1): This is a key step, explained below.

4. Detailed Explanation of Channel Processing

In image segmentation tasks:

  • Input images are batched into four-dimensional tensors of shape [batch_size, channels, height, width], e.g., [1, 3, 1024, 1536].
  • Masks are also four-dimensional but have a single channel: [batch_size, 1, height, width], e.g., [1, 1, 1024, 1536]. The channel axis appears because ToTensor is applied to the grayscale mask inside the dataset.

However, for class-index targets, nn.CrossEntropyLoss requires the mask to have shape [batch_size, height, width] and to contain integer (int64) class IDs. The target therefore must not include the channel dimension.

Thus, the torch.squeeze function is used to remove the channel dimension from the mask:

masks = torch.squeeze(masks, dim=1)

This changes the mask shape from [batch_size, 1, height, width] to [batch_size, height, width], meeting the requirements of the loss function.
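Specifying dim=1 matters here because batch_size is 1. A bare squeeze() would also remove the batch axis, leaving a target that no longer has the [batch_size, height, width] shape the loss expects. A toy-sized comparison:

```python
import torch

masks = torch.zeros(1, 1, 4, 6, dtype=torch.long)  # batch_size=1, as in the training script

print(torch.squeeze(masks, dim=1).shape)  # torch.Size([1, 4, 6]): only the channel axis is removed
print(torch.squeeze(masks).shape)         # torch.Size([4, 6]): the batch axis disappears too
print(masks[:, 0].shape)                  # torch.Size([1, 4, 6]): equivalent indexing form
```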

5. Model Output and Loss Calculation

  • outputs = model(images): The model output shape is [batch_size, num_classes, height, width], e.g., [1, 2, 1024, 1536].
  • loss = criterion(outputs, masks): Computes the cross-entropy loss between the predicted results and the true masks.

6. Model Saving

torch.save(model.state_dict(), './model.pth')
  • Saves the model parameters to the file model.pth, allowing for later loading and inference.
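Only the state_dict (parameters and buffers) is saved, so the model class must be instantiated again before the weights can be loaded. The inference script does exactly this. Below is a minimal round trip, with a single convolution standing in for NestedUNet and an in-memory buffer standing in for ./model.pth.

```python
import io

import torch
import torch.nn as nn

# A 1x1 convolution stands in for NestedUNet; BytesIO stands in for './model.pth'
model = nn.Conv2d(3, 2, kernel_size=1)
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Later: rebuild the same architecture, then copy the saved parameters into it
buffer.seek(0)
restored = nn.Conv2d(3, 2, kernel_size=1)
restored.load_state_dict(torch.load(buffer, map_location='cpu'))

same = all(torch.equal(a, b) for a, b in zip(model.state_dict().values(),
                                              restored.state_dict().values()))
print(same)  # True
```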

Inference Code and Explanation

The inference script below loads the trained NestedUNet weights from ./model.pth and uses the model to segment every image in a list of input directories. It then saves the predicted masks to matching output directories.

import os
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from NestedUNet import NestedUNet  # Model definition file

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load model
def load_model(model, path):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model file not found: {path}")
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model

# Perform inference
def segment_images(model, image_dir, output_dir):
    # Calculate new dimensions, dividing original dimensions by 2
    new_height = 2048 // 2
    new_width = 3072 // 2

    # Data preprocessing (must match the transform used in training)
    transform = transforms.Compose([
        transforms.Resize((new_height, new_width)),  # Resize image to half of original size
        transforms.ToTensor()  # Convert to PyTorch tensor
    ])

    os.makedirs(output_dir, exist_ok=True)

    for filename in os.listdir(image_dir):
        if filename.endswith(('.png', '.jpg', '.jpeg')):
            filepath = os.path.join(image_dir, filename)
            image = Image.open(filepath).convert('RGB')
            input_tensor = transform(image).unsqueeze(0).to(device)  # Add batch dimension

            with torch.no_grad():
                outputs = model(input_tensor)
                prediction = torch.argmax(outputs, dim=1).squeeze(0)  # Get segmentation result

            # Save segmentation result
            output_filename = os.path.splitext(filename)[0] + '_segmentation.png'
            output_path = os.path.join(output_dir, output_filename)

            # Map class values to the range 0-255
            pred_img = prediction.cpu().numpy().astype(np.uint8) * 255
            Image.fromarray(pred_img).save(output_path)

# Main execution code
if __name__ == "__main__":
    model = NestedUNet(num_classes=2, input_channels=3).to(device)
    model = load_model(model, './model.pth')  # Load pre-trained model

    # Define input and output directories
    input_dirs = [
        './dataset/1-2000',
        './dataset/2001-4000',
        './dataset/4001-6000',
        './dataset/6001-8000',
        './dataset/8001-9663'
    ]

    base_output_dir = './dataset/segmentation_results'  # Base output results directory

    for input_dir in input_dirs:
        output_dir = os.path.join(base_output_dir, os.path.basename(input_dir))
        segment_images(model, input_dir, output_dir)

    print(f"Segmentation results saved to: {base_output_dir}")

1. Device Selection

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
  • Selects the computing device based on whether a GPU is available (checked via torch.cuda.is_available()). If a GPU is available, the code will use it; otherwise, it will use the CPU.

2. Load Model

def load_model(model, path):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model file not found: {path}")
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model
  • load_model function:
    • Checks if the model file exists at the specified path.
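    • Loads the saved state_dict with torch.load() and copies it into the model with load_state_dict(). Passing map_location=device lets weights saved on a GPU load on a CPU-only machine.
    • After loading, calls model.eval() to switch to evaluation mode. In this mode, batch normalization uses its stored running statistics instead of per-batch statistics, and dropout (if present) is disabled.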
    • This function returns the model with loaded weights.
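VGGBlock contains BatchNorm2d but no dropout, so in this network the practical effect of model.eval() is on batch normalization. The small sketch below uses a standalone BatchNorm2d layer (toy shapes, arbitrary data). It shows the two modes producing different outputs for the same input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
x = torch.randn(8, 4, 5, 5) * 3 + 2   # features with non-zero mean and non-unit variance

bn.train()
y_train = bn(x)   # normalized with this batch's mean/variance; running stats are updated

bn.eval()
y_eval = bn(x)    # normalized with running_mean / running_var instead

print(y_train.mean(dim=(0, 2, 3)))      # close to 0 for every channel
print(torch.allclose(y_train, y_eval))  # False: eval mode ignores the batch statistics
```

Without eval(), each prediction at inference would depend on the statistics of its own single-image batch.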

3. Perform Image Segmentation Inference

def segment_images(model, image_dir, output_dir):
    new_height = 2048 // 2
    new_width = 3072 // 2
  • Sets the target image size, reducing the original image height and width by half (2048 // 2 and 3072 // 2).

Data Preprocessing

transform = transforms.Compose([
    transforms.Resize((new_height, new_width)),  # Resize image
    transforms.ToTensor()  # Convert to PyTorch tensor
])
  • Images are resized to the new dimensions through the Resize transformation.
  • Then converted to PyTorch tensor format using ToTensor(), making it suitable for model input.

Processing Each Image

for filename in os.listdir(image_dir):
    if filename.endswith(('.png', '.jpg', '.jpeg')):
        filepath = os.path.join(image_dir, filename)
        image = Image.open(filepath).convert('RGB')
        input_tensor = transform(image).unsqueeze(0).to(device)  # Add batch dimension
  • Iterates through all files in the image_dir directory and keeps those ending in .png, .jpg, or .jpeg. The check is case-sensitive, so files such as IMG_01.JPG are skipped.
  • Reads each image and converts it to RGB mode (even grayscale images will be processed as RGB).
  • Uses the preprocessing transform to convert it to a tensor and adds a batch dimension (unsqueeze(0)), making the shape [1, C, H, W] (suitable for model input).
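If the input folders may contain upper-case extensions such as .JPG, a case-insensitive filter avoids silently skipping them. list_images below is a hypothetical helper, not part of the original script.

```python
import os

IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def list_images(image_dir):
    """Hypothetical helper: sorted image filenames, matching extensions case-insensitively."""
    return sorted(
        name for name in os.listdir(image_dir)
        if name.lower().endswith(IMAGE_EXTENSIONS)               # .JPG, .Png, ... also match
        and os.path.isfile(os.path.join(image_dir, name))        # skip directories
    )
```

It could replace the os.listdir loop in segment_images (for filename in list_images(image_dir):), with the rest of the loop body unchanged.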

Model Inference

with torch.no_grad():
    outputs = model(input_tensor)
    prediction = torch.argmax(outputs, dim=1).squeeze(0)  # Get segmentation result
  • Uses torch.no_grad() to disable gradient calculation, saving memory and speeding up inference.
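  • model(input_tensor) returns raw, unnormalized class scores (logits) with shape [1, num_classes, H, W]. Applying softmax along dim=1 would turn them into per-pixel probabilities.
  • torch.argmax(outputs, dim=1): For each pixel, takes the class with the highest score as the predicted class. Softmax preserves ordering, so this is also the class with the highest probability.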
  • squeeze(0): Removes the batch dimension, resulting in prediction with a shape of [H, W].
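The relation between logits, probabilities, and the argmax prediction can be checked on random stand-in outputs of shape [1, 2, 4, 6]. The size is a toy one, and no real model is involved.

```python
import torch

torch.manual_seed(0)
outputs = torch.randn(1, 2, 4, 6)  # stand-in for model(input_tensor): [1, num_classes, H, W] logits

prediction = torch.argmax(outputs, dim=1).squeeze(0)  # [H, W] class IDs, as in segment_images
probs = torch.softmax(outputs, dim=1)                 # per-pixel class probabilities (sum to 1)
prediction_from_probs = torch.argmax(probs, dim=1).squeeze(0)

print(prediction.shape)                                # torch.Size([4, 6])
print(torch.equal(prediction, prediction_from_probs))  # True: softmax is monotonic
```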

Save Segmentation Result

output_filename = os.path.splitext(filename)[0] + '_segmentation.png'
output_path = os.path.join(output_dir, output_filename)

pred_img = prediction.cpu().numpy().astype(np.uint8) * 255
Image.fromarray(pred_img).save(output_path)
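  • output_filename: Names each output file after the input image's base name, followed by _segmentation.png.
  • prediction.cpu().numpy(): Moves the prediction from the GPU (if one is used) to the CPU and converts it to a NumPy array.
  • astype(np.uint8) * 255: Maps the predicted classes (0 or 1) to grayscale values (0 or 255), so the result can be saved as a black-and-white image. The mask keeps the resized resolution (1024 x 1536), not the original 2048 x 3072. If full-resolution masks are needed, resize it back with nearest-neighbor interpolation.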
  • Uses Pillow to save pred_img as a PNG file.
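The prediction is produced at the resized 1024 x 1536 resolution, so a variant of the saving step can scale the mask back to the source image size. save_prediction is a hypothetical helper. It assumes binary class IDs and takes the target size in PIL's (width, height) order, for example image.size captured before the transform is applied.

```python
import numpy as np
from PIL import Image


def save_prediction(prediction, output_path, original_size=None):
    """Hypothetical helper.

    prediction:    2D NumPy array of class IDs (0 or 1), e.g. prediction.cpu().numpy()
    original_size: optional (width, height) in PIL order to resize the mask back to
    """
    pred_img = Image.fromarray(prediction.astype(np.uint8) * 255)  # 0/1 -> 0/255 grayscale
    if original_size is not None:
        # Nearest-neighbor keeps the mask strictly black and white
        pred_img = pred_img.resize(original_size, Image.NEAREST)
    pred_img.save(output_path)
```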

4. Main Execution Code

if __name__ == "__main__":
    model = NestedUNet(num_classes=2, input_channels=3).to(device)
    model = load_model(model, './model.pth')  # Load pre-trained model

    input_dirs = [
        './dataset/1-2000',
        './dataset/2001-4000',
        './dataset/4001-6000',
        './dataset/6001-8000',
        './dataset/8001-9663'
    ]

    base_output_dir = './dataset/segmentation_results'  # Base output results directory

    for input_dir in input_dirs:
        output_dir = os.path.join(base_output_dir, os.path.basename(input_dir))
        segment_images(model, input_dir, output_dir)

    print(f"Segmentation results saved to: {base_output_dir}")
  • The main program first instantiates NestedUNet with the same configuration used in training, then loads the trained weights into it.
  • A list of multiple subfolder paths (input_dirs) is defined, each containing images to be segmented.
  • For each input folder, a corresponding output folder is generated to save the segmentation results.
  • Finally, the path of the saved results directory is printed.

Data Preprocessing Code and Explanation

import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class ImageSegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_files = sorted(os.listdir(image_dir))  # Get list of image files and sort

    def __getitem__(self, idx):
        # Get image filename
        image_file = self.image_files[idx]
        image_path = os.path.join(self.image_dir, image_file)

        # Construct mask filename, assuming mask files end with "_mask"
        mask_file = image_file.replace(".jpg", "_mask.png")
        mask_path = os.path.join(self.mask_dir, mask_file)

        # Load image and mask
        image = Image.open(image_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')  # Grayscale image

        # If there is a transform (data augmentation, etc.), apply it
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        mask = torch.tensor(np.array(mask, dtype=np.int64))
        return image, mask

    def __len__(self):
        # Return the number of image files in the dataset
        return len(self.image_files)

__init__ Constructor

def __init__(self, image_dir, mask_dir, transform=None):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform = transform
    self.image_files = sorted(os.listdir(image_dir))  # Get list of image files and sort
  • image_dir: Directory path where images are stored.
  • mask_dir: Directory path where mask images are stored. Each image will have a corresponding mask image that annotates the target area.
  • transform: If there are data preprocessing or augmentation operations, they can be passed to transform. For example, resizing, normalization, etc.
  • image_files: All filenames in image_dir, sorted so that the dataset order is deterministic. Masks are matched to images by filename in __getitem__, not by position. Every file in the directory is treated as an image.

__getitem__ Method

def __getitem__(self, idx):
    # Get image filename
    image_file = self.image_files[idx]
    image_path = os.path.join(self.image_dir, image_file)

    # Construct mask filename, assuming mask files end with "_mask"
    mask_file = image_file.replace(".jpg", "_mask.png")
    mask_path = os.path.join(self.mask_dir, mask_file)

    # Load image and mask
    image = Image.open(image_path).convert('RGB')
    mask = Image.open(mask_path).convert('L')  # Grayscale image
  • idx: The index passed in, indicating which image and its corresponding mask to load from the dataset.

  • image_file: Gets the current image filename based on idx.

  • image_path: Constructs the full path for the image based on its filename.

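  • mask_file: Derives the mask filename from the image filename by replacing .jpg with _mask.png (e.g., 0001.jpg becomes 0001_mask.png). This rule assumes JPEG images and PNG masks and can be modified as needed. For an image with any other extension, replace() leaves the name unchanged, so the mask path becomes the image's own filename inside mask_dir, which usually does not exist.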

  • mask_path: Constructs the full path for the mask image based on its filename.

  • Load image and mask:

    • Uses Pillow's Image.open() to load the image and ensures it is in RGB format using .convert('RGB').
    • The mask is loaded as a grayscale image using .convert('L').

Apply Preprocessing Operations

if self.transform:
    image = self.transform(image)
    mask = self.transform(mask)
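  • If a transform is provided, the same transform is applied to both the image and the mask. This is safe for deterministic operations such as Resize, with two caveats:
    • Bilinear resizing blends label values along mask boundaries.
    • Random augmentations (flips, crops) would be sampled independently for the image and the mask, which misaligns them.
  Normalization should be applied to the image only.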

Convert Mask to PyTorch Tensor

mask = torch.tensor(np.array(mask, dtype=np.int64))
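  • Converts the mask to a NumPy int64 array. If the transform contained ToTensor, the mask is by now a float tensor in [0, 1], so the cast truncates. A mask stored as 0/255 becomes class IDs 0 and 1, while any value below 255 (including blended boundary pixels) becomes 0.
  • Then wraps the array in a PyTorch tensor of type int64. int64 (long) is what nn.CrossEntropyLoss expects for class-index targets, where each pixel holds an integer class ID.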
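The interaction between ToTensor and the int64 cast is easy to miss. The NumPy sketch below mimics both steps for two hypothetical mask encodings: ToTensor divides 8-bit pixel values by 255, and the cast then truncates toward zero.

```python
import numpy as np


def emulate_mask_pipeline(mask_uint8):
    """Mimic ToTensor (uint8 / 255 -> float) followed by the int64 cast in __getitem__."""
    scaled = mask_uint8.astype(np.float32) / 255.0
    return np.array(scaled, dtype=np.int64)  # the cast truncates toward zero


masks_0_255 = np.array([[0, 255], [128, 255]], dtype=np.uint8)  # 128: e.g. a blended edge pixel
masks_0_1 = np.array([[0, 1], [1, 0]], dtype=np.uint8)          # masks that store class IDs directly

print(emulate_mask_pipeline(masks_0_255))  # class IDs 0 and 1; the 128 edge pixel becomes 0
print(emulate_mask_pipeline(masks_0_1))    # all zeros: every pixel becomes background
```

The pipeline as written therefore works for masks stored as 0/255. Masks that already contain class IDs 0/1, however, would be turned entirely into background.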

__len__ Method

def __len__(self):
    # Return the number of image files in the dataset
    return len(self.image_files)
  • Returns the number of image files in the dataset. DataLoader calls len() on the dataset to learn how many samples exist, which it needs in order to generate (and shuffle) indices.

Nested U-Net

The model code below implements Nested U-Net (also known as UNet++), an image segmentation architecture built on the traditional U-Net. It aims to improve segmentation accuracy by replacing U-Net's plain skip connections with nested, densely connected convolution blocks between the encoder and decoder. Each part of the code is explained below, with a focus on the role of each module.

VGGBlock

import torch
import torch.nn as nn


class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        return out
  • VGGBlock is a core convolutional block in the model. Each block consists of:
    • Convolutional Layers: conv1 and conv2 both use a 3x3 kernel with padding=1, so the output height and width match the input.
    • Batch Normalization Layers: bn1 and bn2, used to accelerate training and stabilize the model.
    • ReLU Activation Function: Increases the non-linearity of the model.

This block is repeatedly called to form the basis of U-Net and Nested U-Net.
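For a 3x3 convolution with stride 1 and padding=1, the output height is H + 2*1 - 3 + 1 = H, and likewise for width. A VGGBlock therefore changes only the channel count. The quick shape check below uses an nn.Sequential that mirrors VGGBlock(3, 32, 32), with a toy input size.

```python
import torch
import torch.nn as nn

# Same layer sequence as VGGBlock(3, 32, 32): (conv 3x3, padding 1) -> BN -> ReLU, twice
block = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True),
    nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True),
)

x = torch.randn(2, 3, 16, 24)
y = block(x)
print(y.shape)  # torch.Size([2, 32, 16, 24]): channels change, height and width do not
```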

NestedUNet

class NestedUNet(nn.Module):
    def __init__(self, num_classes=2, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.deep_supervision = deep_supervision
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Define convolution modules for each layer
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # Define nested convolution modules (i.e., skip connections)
        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])

        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])

        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])

        # Final output layer, supporting deep supervision
        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
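The input channel counts of the nested blocks follow a single pattern. Node x{i}_{j} concatenates j earlier feature maps from the same depth (nb_filter[i] channels each) with one upsampled map from the depth below (nb_filter[i+1] channels). This short sketch recomputes every in_channels value from that rule.

```python
nb_filter = [32, 64, 128, 256, 512]

# Node x{i}_{j} (depth i, column j >= 1) takes x{i}_0 ... x{i}_{j-1} (nb_filter[i] channels each)
# plus up(x{i+1}_{j-1}) (nb_filter[i+1] channels)
in_channels = {}
for i in range(4):
    for j in range(1, 5 - i):
        in_channels[f"conv{i}_{j}"] = nb_filter[i] * j + nb_filter[i + 1]

for name, ch in in_channels.items():
    print(f"{name}: {ch} input channels")
```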

Main Modules
  • Convolutional Layers: Each node is a VGGBlock. Along the encoder path (conv0_0 to conv4_0), the number of channels increases (32, 64, 128, 256, 512). The nested nodes at each depth keep that depth's channel count and fuse features through skip connections.

  • Skip Connections: This design is key to Nested U-Net. The output of each layer is not only passed to the next layer but also concatenated with outputs from other layers, which helps retain fine detail and improves segmentation accuracy.

  • Upsampling: nn.Upsample doubles the height and width of a deeper feature map with bilinear interpolation, so it can be concatenated with the same-depth skip features. A VGGBlock is then applied to the concatenated result.

  • Deep Supervision: With deep_supervision=True, a 1x1 convolution is attached to each of x0_1, x0_2, x0_3, and x0_4. The model then returns four full-resolution predictions, each of which can be supervised with its own loss, so the intermediate nested nodes are trained directly. With deep_supervision=False (as in the training script), only x0_4 produces the output.
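The training script uses deep_supervision=False. With deep supervision enabled, model(images) returns a list of four outputs, so the loss has to be computed per head. The sketch below assumes equal averaging of the per-head losses as one common choice. deep_supervision_loss is a hypothetical helper, and random tensors stand in for real outputs.

```python
import torch
import torch.nn as nn


def deep_supervision_loss(outputs, masks, criterion):
    """Hypothetical helper: average the loss over every head returned with deep_supervision=True.

    outputs: list of [B, num_classes, H, W] logits (final1 ... final4)
    masks:   [B, H, W] int64 class IDs
    """
    return sum(criterion(out, masks) for out in outputs) / len(outputs)


torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()
outputs = [torch.randn(1, 2, 8, 12) for _ in range(4)]  # stand-ins for the four heads
masks = torch.randint(0, 2, (1, 8, 12))
loss = deep_supervision_loss(outputs, masks, criterion)

# At inference, one option is to average the heads' logits before taking argmax
prediction = torch.stack(outputs).mean(dim=0).argmax(dim=1)  # [B, H, W]
```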


Forward Method

def forward(self, input):
    # Various convolution operations
    x0_0 = self.conv0_0(input)

    x1_0 = self.conv1_0(self.pool(x0_0))
    x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

    x2_0 = self.conv2_0(self.pool(x1_0))
    x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
    x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

    # Continue nested connections and convolution operations until the last layer
    x3_0 = self.conv3_0(self.pool(x2_0))
    x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
    x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
    x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

    x4_0 = self.conv4_0(self.pool(x3_0))
    x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
    x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
    x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
    x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

    if self.deep_supervision:
        output1 = self.final1(x0_1)
        output2 = self.final2(x0_2)
        output3 = self.final3(x0_3)
        output4 = self.final4(x0_4)
        return [output1, output2, output3, output4]
    else:
        output = self.final(x0_4)
        return output
  • Convolution and Pooling Operations: self.pool (2x2 max pooling) halves the spatial size on the way down. self.up doubles it again with bilinear interpolation (not a learned deconvolution), and the outputs of different layers are then concatenated along the channel dimension.
  • Deep Supervision Outputs: If deep supervision is enabled, results are produced at multiple intermediate layers; otherwise, only the final output is produced.
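Each self.pool step floors odd sizes, while each self.up step exactly doubles. The torch.cat calls in forward therefore only line up when the input height and width are divisible by 2^4 = 16. The 1024 x 1536 inputs used in this post satisfy that. level_sizes and fits_unetpp below are hypothetical helpers for checking other sizes.

```python
def level_sizes(size, depth=4):
    """Spatial size at each encoder level after repeated MaxPool2d(2, 2), which floors."""
    sizes = [size]
    for _ in range(depth):
        sizes.append(sizes[-1] // 2)
    return sizes


def fits_unetpp(height, width, depth=4):
    """True if every up(x) matches the size of the skip tensor it is concatenated with."""
    return all(
        s_next * 2 == s
        for dim in (height, width)
        for s, s_next in zip(level_sizes(dim, depth), level_sizes(dim, depth)[1:])
    )


print(level_sizes(1024), fits_unetpp(1024, 1536))  # [1024, 512, 256, 128, 64] True
print(level_sizes(1000), fits_unetpp(1000, 1500))  # [1000, 500, 250, 125, 62] False
```

For example, with a height of 1000, x3_0 is 125 rows tall but self.up(x4_0) has 124. The concatenation for x3_1 therefore cannot succeed.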
