Python U-Net医学影像分割实现详解

医学影像分割是医学图像处理中的一个重要任务,旨在从医学影像中自动分割出感兴趣的区域,如器官、病变等。这对于疾病的诊断、治疗规划和医学研究具有重要意义。U-Net 是一种专门用于医学影像分割的卷积神经网络,因其在处理医学影像时的高效性和准确性而被广泛应用。本文介绍如何使用 Python 和 PyTorch 实现 U-Net 模型进行医学影像分割。

二、数据集概述

1. 数据集来源

我们使用的是 DRIVE 数据集,这是一个公开的视网膜图像数据集,用于血管分割任务。数据集包含 40 张视网膜图像及其对应的分割掩码,其中 20 张用于训练,20 张用于测试。数据集可以从 DRIVE 数据集官网 下载。

2. 数据集特点

  • 图像尺寸:284×388 像素
  • 图像类型:彩色图像
  • 分割掩码:二值图像,白色表示血管,黑色表示背景
  • 3. 数据集结构

    数据集分为训练集和测试集,每个集合包含图像和对应的分割掩码。图像以 TIFF 格式存储,分割掩码以 GIF 格式存储。

    三、技术实现

    1. 环境准备

    在开始之前,确保已安装以下必要的 Python 库:

  • torch:深度学习框架,支持模型的训练和推理。
  • torchvision:提供了常用的计算机视觉工具和模型。
  • numpy:用于数值计算。
  • Pillow:用于图像处理。
  • 安装命令如下:

    pip install torch torchvision numpy pillow
    

    2. 数据加载与预处理

    2.1 数据加载

    我们使用 torchvisionDataset 类来加载数据集。以下是一个示例代码:

    import os
    from PIL import Image
    from torchvision import transforms
    from torch.utils.data import Dataset
    
    class DRIVEDataset(Dataset):
        def __init__(self, image_dir, mask_dir, transform=None):
            self.image_dir = image_dir
            self.mask_dir = mask_dir
            self.transform = transform
            self.image_files = os.listdir(image_dir)
    
        def __len__(self):
            return len(self.image_files)
    
        def __getitem__(self, idx):
            image_path = os.path.join(self.image_dir, self.image_files[idx])
            mask_path = os.path.join(self.mask_dir, self.image_files[idx].replace('.tif', '_mask.gif'))
    
            image = Image.open(image_path).convert('RGB')
            mask = Image.open(mask_path).convert('L')
    
            if self.transform:
                image = self.transform(image)
                mask = self.transform(mask)
    
            return image, mask
    
    2.2 数据预处理

    我们对图像和分割掩码进行以下预处理:

  • 调整图像大小为 256×256 像素
  • 将图像转换为张量
  • 对图像进行归一化处理
  • transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    dataset = DRIVEDataset(image_dir='./DRIVE/train/images', mask_dir='./DRIVE/train/masks', transform=transform)
    

    3. 模型构建

    3.1 U-Net 模型

    U-Net 模型由编码器和解码器组成,编码器用于提取特征,解码器用于生成分割掩码。以下是一个简单的 U-Net 模型实现:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class UNet(nn.Module):
        def __init__(self, in_channels, num_classes):
            super(UNet, self).__init__()
            self.encoder = nn.Sequential(
                nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 128, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2)
            )
            self.decoder = nn.Sequential(
                nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, num_classes, kernel_size=3, padding=1)
            )
    
        def forward(self, x):
            x = self.encoder(x)
            x = self.decoder(x)
            return x
    

    4. 模型训练

    4.1 配置训练参数
    from torch.utils.data import DataLoader
    from torch.optim import Adam
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    model = UNet(in_channels=3, num_classes=1).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = Adam(model.parameters(), lr=0.001)
    
    train_loader = DataLoader(dataset, batch_size=8, shuffle=True)
    
    4.2 训练模型
    num_epochs = 10
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)
    
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
    
            running_loss += loss.item() * images.size(0)
    
        epoch_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')
    

    5. 模型评估

    5.1 加载测试数据
    test_dataset = DRIVEDataset(image_dir='./DRIVE/test/images', mask_dir='./DRIVE/test/masks', transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)
    
    5.2 评估模型
    model.eval()
    with torch.no_grad():
        total_correct = 0
        total_pixels = 0
        for images, masks in test_loader:
            images = images.to(device)
            masks = masks.to(device)
    
            outputs = model(images)
            preds = torch.sigmoid(outputs) > 0.5
    
            total_correct += (preds == masks).sum().item()
            total_pixels += masks.numel()
    
        accuracy = total_correct / total_pixels
        print(f'Test Accuracy: {accuracy:.4f}')
    

    四、结果展示

    1. 分割效果

    以下是模型在测试集上的分割效果示例:

    import matplotlib.pyplot as plt
    
    def visualize_segmentation(image, mask, pred):
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 3, 1)
        plt.imshow(image)
        plt.title('Input Image')
        plt.subplot(1, 3, 2)
        plt.imshow(mask, cmap='gray')
        plt.title('Ground Truth Mask')
        plt.subplot(1, 3, 3)
        plt.imshow(pred, cmap='gray')
        plt.title('Predicted Mask')
        plt.show()
    
    # 可视化示例
    image, mask = test_dataset[0]
    output = model(image.unsqueeze(0).to(device))
    pred = torch.sigmoid(output) > 0.5
    visualize_segmentation(image.permute(1, 2, 0).numpy(), mask.numpy(), pred.squeeze(0).numpy())
    


    五、总结

    通过上述步骤,我们成功地使用 Python 和 PyTorch 实现了 U-Net 模型进行医学影像分割。U-Net 模型在医学影像分割任务中表现出色,能够准确地分割出感兴趣的区域。在实际应用中,可以根据具体需求对模型进行进一步的优化和改进,例如调整网络结构、增加数据增强等。

    作者:Solomon_肖哥弹架构

    物联沃分享整理
    物联沃-IOTWORD物联网 » Python U-Net医学影像分割实现详解

    发表回复