
from pathlib import Path
from urllib.request import Request, urlopen
import json
import shutil
import zipfile
# Manifest URLs for the downloadable resources this script depends on.
# Each entry is fetched and parsed as JSON by the download loop below.
RESOURCE_MANIFESTS = [
"https://assets.deeplearningnotes.com/code-support-resources/datasets/drive/latest.json",
]
def download_file(url, path, timeout=60.0):
    """Download ``url`` and write the response body to ``path``.

    Args:
        url: Address to fetch.
        path: Destination file. It is opened in binary write mode, so any
            existing content is replaced.
        timeout: Seconds to wait on blocking network operations. Without a
            timeout a stalled server would hang the script indefinitely.

    Raises:
        urllib.error.URLError: If the request cannot be completed.
        OSError: If ``path`` cannot be opened for writing.
    """
    # A browser-like User-Agent is sent explicitly instead of urllib's default.
    request = Request(url, headers={"User-Agent": "Mozilla/5.0"})
    with urlopen(request, timeout=timeout) as response, open(path, "wb") as file:
        # Stream the body to disk rather than holding it all in memory.
        shutil.copyfileobj(response, file)
def extract_zip_safely(zip_path, target_dir="."):
    """Extract ``zip_path`` into ``target_dir``, rejecting escaping members.

    Every member path is resolved against the target directory before
    anything is written. If any member would land outside the target
    directory, the function raises and nothing is extracted.

    Args:
        zip_path: Path of the archive to read.
        target_dir: Directory to extract into; defaults to the current
            working directory.

    Raises:
        RuntimeError: If a member resolves to a location outside
            ``target_dir``.
    """
    target_dir = Path(target_dir).resolve()
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        for member in zip_ref.infolist():
            target_path = (target_dir / member.filename).resolve()
            # Compare path components, not string prefixes: a plain
            # ``startswith`` test would accept a sibling directory such as
            # ``<target>_evil`` because its string form shares the prefix.
            inside_target = (
                target_path == target_dir or target_dir in target_path.parents
            )
            if not inside_target:
                raise RuntimeError(f"Unsafe zip path: {member.filename}")
        zip_ref.extractall(target_dir)
# Fetch every manifest and make sure the files it lists exist locally,
# downloading and unpacking the archive only when something is absent.
for manifest_url in RESOURCE_MANIFESTS:
    # The resource name is the URL segment just before the manifest file name.
    name = manifest_url.rstrip("/").split("/")[-2]
    manifest_path = Path(f"{name}-latest.json")
    archive_path = Path(f"{name}.zip")

    download_file(manifest_url, manifest_path)
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))

    expected_paths = [Path(entry) for entry in manifest.get("expected_paths", [])]
    everything_present = all(entry.exists() for entry in expected_paths)
    if not everything_present:
        download_file(manifest["archive_url"], archive_path)
        extract_zip_safely(archive_path, manifest.get("extract_to", "."))

    # Re-check after the (possible) extraction so a bad archive fails loudly.
    missing = [str(entry) for entry in expected_paths if not entry.exists()]
    if missing:
        raise FileNotFoundError(f"Missing expected paths: {missing}")

    print(f"{manifest['name']} ready.")

# Imports
Libraries for retina vessel segmentation with a U-Net style model.
import copy
import random
import time
import re
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision.transforms import functional as TF
# Report which PyTorch build is installed.
print("PyTorch version:", torch.__version__)

# Hyperparameters
Configure dataset paths, crop-based training, and transfer-vs-scratch mode.
# Central configuration read by the rest of the file.
hyperparameters = {
    'dataset_roots': ['./datasets/DRIVE'],
    'image_size': 224,
    'train_crop_size': 224,
    'batch_size': 8,
    'num_epochs': 45,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'num_workers': 2,
    'persistent_workers': True,
    'train_valid_split': 0.8,
    'random_seed': 42,

    # GPU selection (default: use GPU 1)
    'gpu_id': 1,

    # 'transfer' -> ResNet50 encoder initialized from ImageNet
    # 'scratch'  -> all layers randomly initialized
    'training_mode': 'scratch',
    'encoder_name': 'resnet50',

    # random crops per image for one training epoch
    'train_samples_per_image': 20,

    # sliding window parameters for full-image test segmentation
    'inference_tile_size': 224,
    'inference_stride': 112,
    'inference_blend': 'hann',  # options: uniform, hann
}

# Seed every random number generator used in this file.
seed = hyperparameters['random_seed']
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Pick the compute device: the requested GPU when it exists, otherwise the
# first GPU, otherwise the CPU.
requested_gpu_id = hyperparameters.get('gpu_id', 1)
if not torch.cuda.is_available():
    device = torch.device('cpu')
else:
    n_gpus = torch.cuda.device_count()
    if requested_gpu_id >= n_gpus:
        print(f"Requested gpu_id={requested_gpu_id}, but only {n_gpus} GPU(s) available. Falling back to cuda:0.")
        device = torch.device('cuda:0')
    else:
        device = torch.device(f'cuda:{requested_gpu_id}')

print('Using device:', device)
print('Requested gpu_id:', requested_gpu_id)
print('Random seed:', seed)
def resolve_dataset_root(roots):
    """Return the first candidate directory that looks like a DRIVE root.

    A candidate qualifies when both a ``train`` and a ``test`` entry exist
    directly inside it. Candidates are tried in the order given.

    Args:
        roots: Iterable of path-like candidates.

    Returns:
        The first qualifying candidate as a ``Path``.

    Raises:
        FileNotFoundError: If no candidate qualifies.
    """
    required_entries = ('train', 'test')
    for root in roots:
        candidate = Path(root)
        if all((candidate / entry).exists() for entry in required_entries):
            return candidate

    raise FileNotFoundError(
        f"Could not find DRIVE root in any of: {roots}. Expected train/ and test/ subfolders."
    )
# Locate the dataset once and derive the image folders from it.
dataset_root = resolve_dataset_root(hyperparameters['dataset_roots'])
train_root = dataset_root.joinpath('train')
test_root = dataset_root.joinpath('test')
print('Using dataset root:', dataset_root.resolve())

# U-Net Model And Dice Loss
ResNet50-encoder U-Net, configurable in transfer learning or from-scratch mode.
def dice_coefficient_from_logits(logits, targets, threshold=0.5, eps=1e-6):
    """Compute the batch-mean Dice score of thresholded sigmoid predictions.

    Both tensors are reduced over dimensions 1, 2 and 3, so they are expected
    to be 4-D with the sample index in dimension 0.

    Args:
        logits: Raw model outputs (pre-sigmoid).
        targets: Ground-truth tensor of the same shape; cast to float.
        threshold: Probability above which a pixel counts as foreground.
        eps: Smoothing constant added to numerator and denominator, so an
            empty prediction on an empty target scores 1.

    Returns:
        Scalar tensor holding the mean per-sample Dice score.
    """
    per_sample_dims = (1, 2, 3)
    hard_preds = torch.sigmoid(logits).gt(threshold).float()
    target_values = targets.float()

    overlap = (hard_preds * target_values).sum(dim=per_sample_dims)
    total_mass = hard_preds.sum(dim=per_sample_dims) + target_values.sum(dim=per_sample_dims)
    per_sample_dice = (2.0 * overlap + eps) / (total_mass + eps)
    return per_sample_dice.mean()
class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid probabilities.

    Uses the V-Net formulation, in which the denominator sums the squared
    probabilities and squared targets.
    """

    def __init__(self, eps=1e-6):
        """Store the smoothing constant added to numerator and denominator."""
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        """Return ``1 - mean(per-sample Dice)``.

        Reductions run over dimensions 1, 2 and 3, so inputs are expected to
        be 4-D with the sample index in dimension 0.
        """
        reduce_dims = (1, 2, 3)
        probabilities = torch.sigmoid(logits)
        target_values = targets.float()

        overlap = torch.sum(probabilities * target_values, dim=reduce_dims)
        squared_mass = torch.sum(probabilities.pow(2) + target_values.pow(2), dim=reduce_dims)
        per_sample_dice = (2.0 * overlap + self.eps) / (squared_mass + self.eps)
        return 1.0 - per_sample_dice.mean()
class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by batch normalization and ReLU.

    The convolutions use padding 1 and no bias. The first maps ``in_ch`` to
    ``out_ch`` channels; the second keeps ``out_ch`` channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # Build the two conv/norm/activation stages in order so submodule
        # indices inside ``self.block`` stay 0..5.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=False))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x``."""
        return self.block(x)
class UpBlock(nn.Module):
    """Decoder stage: upsample, concatenate the skip features, then convolve.

    Args:
        in_ch: Channels of the incoming decoder feature map.
        skip_ch: Channels of the skip connection that gets concatenated.
        out_ch: Channels produced by the transposed convolution and by the
            final ``ConvBlock``.
    """

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = ConvBlock(out_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        """Upsample ``x``, align it with ``skip`` and fuse the two."""
        upsampled = self.up(x)
        skip_hw = skip.shape[-2:]
        # The transposed convolution may not land exactly on the skip's
        # spatial size; resample when it does not.
        if upsampled.shape[-2:] != skip_hw:
            upsampled = F.interpolate(upsampled, size=skip_hw, mode='bilinear', align_corners=False)
        fused = torch.cat([upsampled, skip], dim=1)
        return self.conv(fused)
class ResNet50UNet(nn.Module):
    """U-Net style segmentation network built on a torchvision ResNet-50.

    The encoder reuses the ResNet-50 stem, max-pool and four residual
    stages. The decoder is a ``ConvBlock`` centre followed by four
    ``UpBlock`` stages and a head that outputs one logit channel.

    Args:
        pretrained_encoder: When truthy, the encoder is created with
            ``ResNet50_Weights.DEFAULT``; otherwise with ``weights=None``.
    """

    def __init__(self, pretrained_encoder=True):
        super().__init__()
        if pretrained_encoder:
            weights = models.ResNet50_Weights.DEFAULT
        else:
            weights = None
        encoder = models.resnet50(weights=weights)

        # Encoder pieces are taken directly from the ResNet-50 instance.
        self.stem = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
        self.pool = encoder.maxpool
        self.enc1 = encoder.layer1
        self.enc2 = encoder.layer2
        self.enc3 = encoder.layer3
        self.enc4 = encoder.layer4

        self.center = ConvBlock(2048, 1024)

        # (attribute name, in_ch, skip_ch, out_ch) for each decoder stage,
        # listed deepest first so modules register in the order up4..up1.
        decoder_spec = (
            ('up4', 1024, 1024, 512),
            ('up3', 512, 512, 256),
            ('up2', 256, 256, 128),
            ('up1', 128, 64, 64),
        )
        for attr_name, in_ch, skip_ch, out_ch in decoder_spec:
            setattr(self, attr_name, UpBlock(in_ch, skip_ch, out_ch))

        self.head = nn.Sequential(
            ConvBlock(64, 32),
            nn.Conv2d(32, 1, kernel_size=1),
        )

    def forward(self, x):
        """Return one-channel logits resized to the spatial size of ``x``."""
        stem_out = self.stem(x)
        skip1 = self.enc1(self.pool(stem_out))
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        bottleneck = self.center(self.enc4(skip3))

        decoded = self.up4(bottleneck, skip3)
        decoded = self.up3(decoded, skip2)
        decoded = self.up2(decoded, skip1)
        decoded = self.up1(decoded, stem_out)

        logits = self.head(decoded)
        # The decoder output is resampled so the logits match the input size.
        return F.interpolate(logits, size=x.shape[-2:], mode='bilinear', align_corners=False)
def build_model(hparams):
    """Create the segmentation model described by ``hparams``.

    Args:
        hparams: Mapping with ``encoder_name`` and ``training_mode`` string
            entries; both are compared case-insensitively.

    Returns:
        A ``ResNet50UNet`` whose encoder is pretrained for mode
        ``'transfer'`` and randomly initialized for mode ``'scratch'``.

    Raises:
        ValueError: If the encoder is not ``'resnet50'`` or the mode is not
            one of the two supported values.
    """
    encoder_name = hparams['encoder_name'].lower()
    if encoder_name != 'resnet50':
        raise ValueError("Only encoder_name='resnet50' is implemented in this notebook.")

    mode = hparams['training_mode'].lower()
    if mode == 'transfer':
        use_pretrained = True
    elif mode == 'scratch':
        use_pretrained = False
    else:
        raise ValueError("training_mode must be either 'transfer' or 'scratch'.")

    return ResNet50UNet(pretrained_encoder=use_pretrained)
def count_trainable_parameters(model):
    """Count the parameters of ``model``.

    Returns:
        A ``(total, trainable)`` tuple of element counts, where ``trainable``
        only includes parameters with ``requires_grad`` set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        n_elements = param.numel()
        total += n_elements
        if param.requires_grad:
            trainable += n_elements
    return total, trainable

# Optimization Setup
Dice loss, AdamW optimizer, and cosine learning-rate schedule.
# Loss object passed to both the training step and full-image evaluation.
criterion = DiceLoss().to(device)
def build_optimizer_and_scheduler(model, hparams):
    """Create the AdamW optimizer and cosine-annealing scheduler for ``model``.

    Args:
        model: Module whose parameters are optimized.
        hparams: Mapping providing ``learning_rate`` and ``num_epochs``;
            ``weight_decay`` is optional and defaults to 0.0.

    Returns:
        An ``(optimizer, scheduler)`` tuple. The scheduler's ``T_max`` is the
        number of epochs.
    """
    parameters = model.parameters()
    learning_rate = hparams['learning_rate']
    weight_decay = hparams.get('weight_decay', 0.0)
    optimizer = optim.AdamW(parameters, lr=learning_rate, weight_decay=weight_decay)

    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=hparams['num_epochs'])
    return optimizer, scheduler

# Model Summary and Trainable Parameters
Inspect the model and verify how many parameters are trainable under the selected transfer mode.
# ``model`` is only created further down, so look it up defensively: this
# cell prints a hint instead of failing when it runs before the model exists.
_model = globals().get('model', None)
if _model is not None:
    total_params, trainable_params = count_trainable_parameters(_model)
    print("Model: ResNet50 U-Net")
    print(f"Training mode: {hyperparameters['training_mode']}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
else:
    print('Model not found yet. Run the dataset/dataloader cell first, then re-run this cell.')

# Build DRIVE Splits
Load train/test from DRIVE, then split train into 80% train and 20% validation.
def extract_numeric_id(path: Path):
    """Return the first run of digits in the file stem as an int.

    Only ``path.stem`` is searched, so digits in parent directories or in
    the suffix are ignored. Returns ``None`` when the stem has no digits.
    """
    digits = re.search(r'(\d+)', path.stem)
    if digits is None:
        return None
    return int(digits.group(1))
def build_drive_pairs(image_dir: Path, gt_dir: Path):
    """Pair each PNG image with the ground-truth file sharing its numeric id.

    Images are the ``*.png`` files in ``image_dir``. Ground-truth candidates
    are files in ``gt_dir`` with a gif/png/tif/tiff/jpg/jpeg suffix
    (case-insensitive). The id is the first run of digits in the file stem.
    When several ground-truth files share an id, the last one in sorted
    order is used.

    Returns:
        A list of ``(image_path, mask_path)`` tuples sorted by numeric id.

    Raises:
        RuntimeError: If no pair could be formed.
    """
    mask_suffixes = {'.gif', '.png', '.tif', '.tiff', '.jpg', '.jpeg'}
    image_paths = sorted(image_dir.glob('*.png'))
    gt_paths = sorted(
        candidate for candidate in gt_dir.glob('*')
        if candidate.suffix.lower() in mask_suffixes
    )

    # Index the masks by numeric id; files without digits are skipped.
    gt_by_id = {}
    for gt_path in gt_paths:
        gt_id = extract_numeric_id(gt_path)
        if gt_id is not None:
            gt_by_id[gt_id] = gt_path

    pairs = []
    missing = []
    for image_path in image_paths:
        image_id = extract_numeric_id(image_path)
        if image_id is not None and image_id in gt_by_id:
            pairs.append((image_path, gt_by_id[image_id]))
        else:
            missing.append(image_path.name)

    if missing:
        print('Warning: images without matching ground truth:', missing)

    if not pairs:
        raise RuntimeError(
            f'No image/mask pairs found. Checked images in {image_dir} and masks in {gt_dir}.'
        )

    return sorted(pairs, key=lambda pair: extract_numeric_id(pair[0]))
class DriveTrainRandomCropDataset(Dataset):
    """Training dataset that yields random square crops of image/mask pairs.

    Each pair is visited ``samples_per_image`` times per epoch; every visit
    draws a fresh crop position and, when ``augment`` is set, independent
    random horizontal and vertical flips.

    Args:
        pairs: Iterable of ``(image_path, mask_path)`` tuples.
        crop_size: Side length of the square crop.
        samples_per_image: How many samples each pair contributes to the
            dataset length.
        augment: Whether to apply the random flips.
    """

    def __init__(self, pairs, crop_size=224, samples_per_image=6, augment=True):
        self.pairs = list(pairs)
        self.crop_size = crop_size
        self.samples_per_image = samples_per_image
        self.augment = augment

    def __len__(self):
        return len(self.pairs) * self.samples_per_image

    def __getitem__(self, idx):
        """Return one ``(image, mask)`` crop; the mask is binarized at 0.5."""
        # Indices beyond the number of pairs wrap around to the same files.
        img_path, mask_path = self.pairs[idx % len(self.pairs)]

        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')
        image = TF.to_tensor(image)
        mask = (TF.to_tensor(mask) > 0.5).float()

        crop = self.crop_size
        height, width = image.shape[-2:]
        if height < crop or width < crop:
            # Zero-pad on the bottom/right so a full crop always fits.
            padding = (0, max(0, crop - width), 0, max(0, crop - height))
            image = F.pad(image, padding, mode='constant', value=0.0)
            mask = F.pad(mask, padding, mode='constant', value=0.0)
            height, width = image.shape[-2:]

        top = random.randint(0, height - crop)
        left = random.randint(0, width - crop)
        rows = slice(top, top + crop)
        cols = slice(left, left + crop)
        image = image[:, rows, cols]
        mask = mask[:, rows, cols]

        if self.augment:
            # Flip image and mask together so they stay aligned.
            if random.random() < 0.5:
                image, mask = TF.hflip(image), TF.hflip(mask)
            if random.random() < 0.5:
                image, mask = TF.vflip(image), TF.vflip(mask)

        return image, mask
class DriveFullImageDataset(Dataset):
    """Dataset that yields whole images with their binarized masks.

    Args:
        pairs: Iterable of ``(image_path, mask_path)`` tuples.
    """

    def __init__(self, pairs):
        self.pairs = list(pairs)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(image, mask, sample_id)`` for the pair at ``idx``.

        ``sample_id`` is the stem of the image file name; the mask is
        thresholded at 0.5.
        """
        img_path, mask_path = self.pairs[idx]
        rgb_image = Image.open(img_path).convert('RGB')
        image = TF.to_tensor(rgb_image)
        gray_mask = Image.open(mask_path).convert('L')
        mask = (TF.to_tensor(gray_mask) > 0.5).float()
        return image, mask, img_path.stem
# Ground-truth folders sit next to the image folders under the dataset root.
train_gt_root = dataset_root / 'train_groundtruth'
test_gt_root = dataset_root / 'test_groundtruth'

all_train_pairs = build_drive_pairs(train_root, train_gt_root)
test_pairs = build_drive_pairs(test_root, test_gt_root)

# Shuffle with a dedicated generator so the split depends only on the seed,
# not on how much of the global random state has been consumed so far.
n_train_all = len(all_train_pairs)
indices = list(range(n_train_all))
random.Random(seed).shuffle(indices)

split = int(n_train_all * hyperparameters['train_valid_split'])
train_idx, valid_idx = indices[:split], indices[split:]
train_pairs = [all_train_pairs[i] for i in train_idx]
valid_pairs = [all_train_pairs[i] for i in valid_idx]

print(f'Train images total: {n_train_all}')
print(f'Train split: {len(train_pairs)}')
print(f'Valid split: {len(valid_pairs)}')
print(f'Test images: {len(test_pairs)}')

# Pairing Sanity Check
Visualize image/mask pairs to verify correct correspondence.
def show_image_mask_pairs(pairs, max_items=4):
    """Plot up to ``max_items`` image/mask pairs side by side.

    Each row shows the RGB image on the left and the grayscale mask on the
    right, titled with their file names.
    """
    n = min(max_items, len(pairs))
    plt.figure(figsize=(10, 3 * n))

    for row in range(n):
        img_path, mask_path = pairs[row]
        image = np.array(Image.open(img_path).convert('RGB'))
        mask = np.array(Image.open(mask_path).convert('L'))

        # (pixel data, colormap, title) for the left and right panel.
        panels = (
            (image, None, f'Image: {img_path.name}'),
            (mask, 'gray', f'Mask: {mask_path.name}'),
        )
        for column, (pixels, cmap, title) in enumerate(panels, start=1):
            plt.subplot(n, 2, 2 * row + column)
            plt.imshow(pixels, cmap=cmap)
            plt.title(title)
            plt.axis('off')

    plt.tight_layout()
    plt.show()
# Visual check that the first few training images line up with their masks.
show_image_mask_pairs(train_pairs, max_items=4)

# Create Segmentation Datasets
Random-crop training dataset and full-image validation/test datasets.
# Training draws random crops; validation and test use whole images.
train_dataset = DriveTrainRandomCropDataset(
    pairs=train_pairs,
    crop_size=hyperparameters['train_crop_size'],
    samples_per_image=hyperparameters['train_samples_per_image'],
    augment=True,
)
valid_dataset = DriveFullImageDataset(valid_pairs)
test_dataset = DriveFullImageDataset(test_pairs)

print('Datasets ready:')
print('  train samples per epoch:', len(train_dataset))
print('  valid full images:', len(valid_dataset))
print('  test full images:', len(test_dataset))

# DataLoaders And Model
Build loaders and instantiate U-Net according to selected mode.
loader_num_workers = hyperparameters.get('num_workers', 2)
# Persistent workers are only requested when worker processes are used.
use_persistent_workers = hyperparameters.get('persistent_workers', True) and loader_num_workers > 0
pin_memory = torch.cuda.is_available()

# Options shared by all three loaders.
_shared_loader_options = {
    'num_workers': loader_num_workers,
    'pin_memory': pin_memory,
    'persistent_workers': use_persistent_workers,
}

train_loader = DataLoader(
    train_dataset,
    batch_size=hyperparameters['batch_size'],
    shuffle=True,
    **_shared_loader_options,
)

# full-image validation/testing are done one image at a time
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, **_shared_loader_options)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, **_shared_loader_options)

model = build_model(hyperparameters).to(device)
optimizer, scheduler = build_optimizer_and_scheduler(model, hyperparameters)

total_params, trainable_params = count_trainable_parameters(model)
print("Model: ResNet50 U-Net")
print(f"Mode: {hyperparameters['training_mode']}")
print(f"Trainable parameters: {trainable_params:,} / {total_params:,}")

# Crop Visualization
Inspect random 224x224 train crops and their vessel masks.
def show_train_crops(images, masks, max_items=4):
    """Plot up to ``max_items`` crops from a batch next to their masks.

    ``images`` is indexed per sample and permuted from channel-first to
    channel-last for display. ``masks`` is indexed as ``masks[i, 0]``, so
    its first channel is the one shown.
    """
    n = min(max_items, images.shape[0])
    plt.figure(figsize=(8, 3 * n))

    for row in range(n):
        crop_pixels = images[row].detach().cpu().permute(1, 2, 0).numpy()
        mask_pixels = masks[row, 0].detach().cpu().numpy()

        # (pixel data, colormap, title) for the left and right panel.
        panels = (
            (crop_pixels, None, 'Random train crop (224x224)'),
            (mask_pixels, 'gray', 'Crop mask'),
        )
        for column, (pixels, cmap, title) in enumerate(panels, start=1):
            plt.subplot(n, 2, 2 * row + column)
            plt.imshow(pixels, cmap=cmap)
            plt.title(title)
            plt.axis('off')

    plt.tight_layout()
    plt.show()
# Pull a single batch to confirm the tensor shapes and look at the crops.
images, masks = next(iter(train_loader))
print(f'Batch image tensor: {images.shape}')
print(f'Batch mask tensor: {masks.shape}')
show_train_crops(images, masks, max_items=4)

# Training Step
One training epoch on random crop samples.
def train_one_epoch(model, train_loader, criterion, optimizer):
    """Run one optimization pass over ``train_loader``.

    Batches are moved to the module-level ``device``. Loss and Dice are
    weighted by batch size so the returned values are per-sample averages.

    Returns:
        ``(avg_loss, avg_dice)`` over all samples seen; both are 0.0 when
        the loader yields nothing.
    """
    model.train()
    loss_total = 0.0
    dice_total = 0.0
    samples_seen = 0

    for images, masks in train_loader:
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(images)
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()

        n_in_batch = images.size(0)
        loss_total += loss.item() * n_in_batch
        dice_total += dice_coefficient_from_logits(logits, masks).item() * n_in_batch
        samples_seen += n_in_batch

    # Guard against division by zero for an empty loader.
    divisor = max(1, samples_seen)
    return loss_total / divisor, dice_total / divisor

# Full-Image Evaluation
Sliding-window inference to segment full retina images and compute Dice.
def _sliding_positions(length, tile_size, stride):
if length <= tile_size:
return [0]
positions = list(range(0, length - tile_size + 1, stride))
if positions[-1] != length - tile_size:
positions.append(length - tile_size)
return positions
def _build_patch_weight(tile_size, mode='uniform', eps=1e-6, device='cpu'):
if mode == 'uniform':
w = torch.ones((tile_size, tile_size), device=device)
elif mode == 'hann':
h = torch.hann_window(tile_size, periodic=False, device=device)
w = torch.outer(h, h)
w = w + eps
else:
raise ValueError("blend mode must be one of: uniform, hann")
return w.view(1, 1, tile_size, tile_size)
def sliding_window_logits(model, image, tile_size=224, stride=112, blend='uniform', return_count_map=False):
    """Run ``model`` tile by tile and merge the logits into one map.

    The image is zero-padded on the bottom/right when it is smaller than a
    tile. Each tile's logits are weighted by the blend window, accumulated,
    and divided by the accumulated weights. The result is cropped back to
    the original height and width.

    Args:
        model: Callable mapping a tile to logits of the same spatial size.
        image: Tensor of shape ``[1, 3, H, W]``.
        tile_size: Side length of the square tiles.
        stride: Step between neighbouring tiles.
        blend: Blend window name, ``'uniform'`` or ``'hann'``.
        return_count_map: When true, also return the accumulated weights.

    Returns:
        The merged logits of shape ``(1, 1, H, W)``, or a
        ``(merged_logits, count_map)`` tuple when ``return_count_map`` is set.
    """
    _, _, h, w = image.shape

    pad_h = max(0, tile_size - h)
    pad_w = max(0, tile_size - w)
    if pad_h or pad_w:
        image = F.pad(image, (0, pad_w, 0, pad_h), mode='constant', value=0.0)
    _, _, hp, wp = image.shape

    y_positions = _sliding_positions(hp, tile_size, stride)
    x_positions = _sliding_positions(wp, tile_size, stride)

    accumulator_shape = (1, 1, hp, wp)
    logits_sum = torch.zeros(accumulator_shape, device=image.device)
    count_map = torch.zeros(accumulator_shape, device=image.device)
    patch_weight = _build_patch_weight(tile_size, mode=blend, device=image.device)

    for y in y_positions:
        rows = slice(y, y + tile_size)
        for x in x_positions:
            cols = slice(x, x + tile_size)
            patch_logits = model(image[:, :, rows, cols])
            logits_sum[:, :, rows, cols] += patch_logits * patch_weight
            count_map[:, :, rows, cols] += patch_weight

    # Normalize by the accumulated weights, then drop any padding.
    merged_logits = (logits_sum / count_map.clamp_min(1e-6))[:, :, :h, :w]
    if not return_count_map:
        return merged_logits
    return merged_logits, count_map[:, :, :h, :w]
def evaluate_full_images(model, data_loader, criterion, tile_size=224, stride=112, threshold=0.5, blend='uniform'):
    """Evaluate ``model`` on full images using sliding-window inference.

    Each batch from ``data_loader`` is counted as one image. Tensors are
    moved to the module-level ``device``; gradients are disabled.

    Args:
        model: Network to evaluate; switched to eval mode.
        data_loader: Yields ``(images, masks, sample_ids)`` batches.
        criterion: Loss applied to the merged logits and the masks.
        tile_size: Tile side length for the sliding window.
        stride: Step between tiles.
        threshold: Probability above which a pixel is predicted foreground.
        blend: Blend window used when merging tiles.

    Returns:
        ``(avg_loss, mean_image_dice, overall_dice)`` where the mean is over
        per-batch Dice values and the overall Dice pools the pixel counts of
        all batches.
    """
    smooth = 1e-6
    model.eval()

    loss_total = 0.0
    dice_total = 0.0
    n_images = 0
    total_intersection = 0.0
    total_union = 0.0

    with torch.no_grad():
        for images, masks, _ in data_loader:
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)

            logits = sliding_window_logits(
                model, images, tile_size=tile_size, stride=stride, blend=blend,
            )
            loss = criterion(logits, masks)

            preds = (torch.sigmoid(logits) > threshold).float()
            intersection = (preds * masks).sum().item()
            # Sum of predicted and true foreground counts (the Dice denominator).
            union = preds.sum().item() + masks.sum().item()

            total_intersection += intersection
            total_union += union
            loss_total += loss.item()
            dice_total += (2.0 * intersection + smooth) / (union + smooth)
            n_images += 1

    divisor = max(1, n_images)
    avg_loss = loss_total / divisor
    mean_image_dice = dice_total / divisor
    overall_dice = (2.0 * total_intersection + smooth) / (total_union + smooth)
    return avg_loss, mean_image_dice, overall_dice
def diagnose_sliding_window_artifacts(model, image, tile_size=224, stride=112):
    """Compare uniform and Hann tile blending on a single image.

    Runs sliding-window inference once per blend mode with gradients
    disabled and converts the first sample/channel of each result to NumPy.

    Returns:
        Dict with keys ``uniform_probs``, ``hann_probs``,
        ``uniform_count_map``, ``hann_count_map`` and ``delta_abs`` (the
        absolute difference between the two probability maps).
    """

    def _infer(blend_mode):
        # One sliding-window pass that also returns the accumulated weights.
        return sliding_window_logits(
            model,
            image,
            tile_size=tile_size,
            stride=stride,
            blend=blend_mode,
            return_count_map=True,
        )

    def _first_map_as_numpy(tensor):
        return tensor[0, 0].cpu().numpy()

    model.eval()
    with torch.no_grad():
        logits_uniform, count_map_uniform = _infer('uniform')
        logits_hann, count_map_hann = _infer('hann')

    probs_uniform = _first_map_as_numpy(torch.sigmoid(logits_uniform))
    probs_hann = _first_map_as_numpy(torch.sigmoid(logits_hann))

    return {
        'uniform_probs': probs_uniform,
        'hann_probs': probs_hann,
        'uniform_count_map': _first_map_as_numpy(count_map_uniform),
        'hann_count_map': _first_map_as_numpy(count_map_hann),
        'delta_abs': np.abs(probs_hann - probs_uniform),
    }

# Train And Test
Train on random crops, validate on full images, and report test Dice metrics.
# Per-epoch history used for the plots after training.
train_losses, valid_losses = [], []
train_dices = []
valid_mean_dices, valid_overall_dices = [], []

# Track the weights of the epoch with the best validation overall Dice.
best_val_dice = -1.0
best_epoch = 0
best_model_state = copy.deepcopy(model.state_dict())

for epoch in range(hyperparameters['num_epochs']):
    start = time.time()

    train_loss, train_dice = train_one_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_mean_dice, val_overall_dice = evaluate_full_images(
        model,
        valid_loader,
        criterion,
        tile_size=hyperparameters['inference_tile_size'],
        stride=hyperparameters['inference_stride'],
        blend=hyperparameters.get('inference_blend', 'uniform'),
    )

    scheduler.step()
    lr = optimizer.param_groups[0]['lr']

    for history, value in (
        (train_losses, train_loss),
        (valid_losses, val_loss),
        (train_dices, train_dice),
        (valid_mean_dices, val_mean_dice),
        (valid_overall_dices, val_overall_dice),
    ):
        history.append(value)

    elapsed = time.time() - start
    print(
        f"Epoch {epoch+1:02d}/{hyperparameters['num_epochs']} | "
        f"{elapsed:.1f}s | lr={lr:.2e} | "
        f"train loss={train_loss:.4f} dice={train_dice:.4f} | "
        f"val loss={val_loss:.4f} mean-dice={val_mean_dice:.4f} overall-dice={val_overall_dice:.4f}"
    )

    # Snapshot the weights whenever validation overall Dice improves.
    if val_overall_dice > best_val_dice:
        best_val_dice = val_overall_dice
        best_epoch = epoch + 1
        best_model_state = copy.deepcopy(model.state_dict())
# Evaluate the best validation checkpoint on the test set.
model.load_state_dict(best_model_state)

test_loss, test_mean_dice, test_overall_dice = evaluate_full_images(
    model,
    test_loader,
    criterion,
    tile_size=hyperparameters['inference_tile_size'],
    stride=hyperparameters['inference_stride'],
    blend=hyperparameters.get('inference_blend', 'uniform'),
)

# Pixel-level test accuracy computed only inside FOV (non-black image pixels)
model.eval()
correct_pixels = 0
total_pixels = 0

with torch.no_grad():
    for images, masks, _ in test_loader:
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        logits = sliding_window_logits(
            model,
            images,
            tile_size=hyperparameters['inference_tile_size'],
            stride=hyperparameters['inference_stride'],
            blend=hyperparameters.get('inference_blend', 'uniform'),
        )
        preds = (torch.sigmoid(logits) > 0.5).float()
        targets = (masks > 0.5).float()

        # FOV = pixels where at least one RGB channel is non-black
        fov_mask = (images.sum(dim=1, keepdim=True) > 0).float()

        agreement = (preds == targets).float()
        correct_pixels += (agreement * fov_mask).sum().item()
        total_pixels += fov_mask.sum().item()

test_accuracy = correct_pixels / max(1.0, total_pixels)

print(f'Best validation overall Dice: {best_val_dice:.4f} at epoch {best_epoch}')
print(f'Test loss: {test_loss:.4f}')
print(f'Test mean-image Dice: {test_mean_dice:.4f}')
print(f'Test overall Dice: {test_overall_dice:.4f}')
print(f'Test pixel accuracy (FOV only): {test_accuracy:.4f}')
# Plot the loss and Dice histories side by side, one point per epoch.
epochs = np.arange(1, len(train_losses) + 1)

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
for series, label in (
    (train_losses, 'Train Dice loss'),
    (valid_losses, 'Valid Dice loss'),
):
    plt.plot(epochs, series, marker='o', label=label)
plt.title('Loss per epoch')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
for series, label in (
    (train_dices, 'Train Dice'),
    (valid_mean_dices, 'Valid mean-image Dice'),
    (valid_overall_dices, 'Valid overall Dice'),
):
    plt.plot(epochs, series, marker='o', label=label)
plt.title('Dice per epoch')
plt.xlabel('Epoch')
plt.ylabel('Dice')
plt.legend()

plt.tight_layout()
plt.show()

# Qualitative Results
Visual comparison of full-image predictions against ground truth masks.
def show_test_predictions(model, data_loader, max_items=4, threshold=0.5):
    """Plot predictions for up to ``max_items`` batches of ``data_loader``.

    Each row shows the image, the ground truth, the prediction merged with
    uniform blending, and the prediction merged with the blend mode from
    ``hyperparameters['inference_blend']``. Both predictions are titled
    with their Dice score against the ground truth. Only the first sample
    of each batch is drawn.
    """

    def _predict(batch, blend_mode):
        # Sliding-window inference followed by thresholding.
        merged = sliding_window_logits(
            model,
            batch,
            tile_size=hyperparameters['inference_tile_size'],
            stride=hyperparameters['inference_stride'],
            blend=blend_mode,
        )
        return (torch.sigmoid(merged) > threshold).float()

    def _dice(pred, target):
        overlap = (pred * target).sum().item()
        total = pred.sum().item() + target.sum().item()
        return (2.0 * overlap + 1e-6) / (total + 1e-6)

    model.eval()
    shown = 0
    plt.figure(figsize=(16, 4 * max_items))

    with torch.no_grad():
        for images, masks, sample_ids in data_loader:
            images = images.to(device)
            masks = masks.to(device)

            pred_uniform = _predict(images, 'uniform')
            pred_blend = _predict(images, hyperparameters.get('inference_blend', 'uniform'))
            dice_u = _dice(pred_uniform, masks)
            dice_b = _dice(pred_blend, masks)

            # (pixel data, colormap, title) for the four panels of this row.
            panels = (
                (images[0].cpu().permute(1, 2, 0).numpy(), None, f'{sample_ids[0]} - image'),
                (masks[0, 0].cpu().numpy(), 'gray', 'ground truth'),
                (pred_uniform[0, 0].cpu().numpy(), 'gray', f'uniform merge (Dice={dice_u:.3f})'),
                (
                    pred_blend[0, 0].cpu().numpy(),
                    'gray',
                    f"{hyperparameters.get('inference_blend', 'uniform')} merge (Dice={dice_b:.3f})",
                ),
            )
            for column, (pixels, cmap, title) in enumerate(panels, start=1):
                plt.subplot(max_items, 4, 4 * shown + column)
                plt.imshow(pixels, cmap=cmap)
                plt.title(title)
                plt.axis('off')

            shown += 1
            if shown >= max_items:
                break

    plt.tight_layout()
    plt.show()
# Render the qualitative comparison for up to four test images.
show_test_predictions(model, test_loader, max_items=4)