from pathlib import Path
from urllib.request import Request, urlopen
import json
import shutil
import zipfile
 
# URLs of JSON manifests describing downloadable resources. The loop below
# reads the keys "name", "archive_url", "expected_paths" and "extract_to" from
# each manifest; the second-to-last URL path segment ("drive") names the local
# manifest/archive files.
RESOURCE_MANIFESTS = [
    "https://assets.deeplearningnotes.com/code-support-resources/datasets/drive/latest.json",
]
 
def download_file(url, path, timeout=60):
    """Download *url* to the local file *path*.

    The body is streamed into a sibling ``<name>.part`` file. That file is
    renamed onto *path* only after the transfer completes, so an interrupted
    download never leaves a truncated file at *path*. Any previous file at
    *path* is also left untouched on failure.

    Args:
        url: Resource URL; a browser-like User-Agent is sent because some hosts
            reject the default urllib agent.
        path: Destination file path (str or Path).
        timeout: Socket timeout in seconds for blocking network operations.
            Without it a stalled server would hang the notebook indefinitely.

    Raises:
        urllib.error.URLError / OSError: Propagated from the transfer or the
            file system; the partial file is removed first.
    """
    path = Path(path)
    partial = path.with_name(path.name + ".part")
    request = Request(url, headers={"User-Agent": "Mozilla/5.0"})
    try:
        with urlopen(request, timeout=timeout) as response, open(partial, "wb") as file:
            shutil.copyfileobj(response, file)
        partial.replace(path)
    except BaseException:
        # Clean up the half-written temp file, then let the error propagate.
        partial.unlink(missing_ok=True)
        raise
 
def extract_zip_safely(zip_path, target_dir="."):
    """Extract *zip_path* into *target_dir*, refusing archives whose members escape it.

    Every member name is resolved against the target directory before anything
    is written. If any member would land outside it, the whole extraction is
    aborted. Escapes include ``../`` components, absolute paths, and sibling
    directories that merely share a name prefix, such as ``out`` vs ``outside``.

    Args:
        zip_path: Path to the ``.zip`` archive.
        target_dir: Destination directory (default: current directory).

    Raises:
        RuntimeError: If any member path would resolve outside *target_dir*.
    """
    target_dir = Path(target_dir).resolve()

    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        for member in zip_ref.infolist():
            target_path = (target_dir / member.filename).resolve()
            # Compare path components, not string prefixes: "/data/out" is a
            # string prefix of "/data/outside/x" but not one of its parents.
            if target_path != target_dir and target_dir not in target_path.parents:
                raise RuntimeError(f"Unsafe zip path: {member.filename}")

        zip_ref.extractall(target_dir)
 
# Fetch each resource manifest and, if any file it promises is absent, download
# and extract the accompanying archive; finally verify that everything exists.
for manifest_url in RESOURCE_MANIFESTS:
    # ".../datasets/drive/latest.json" -> "drive"
    name = manifest_url.rstrip("/").rsplit("/", 2)[-2]
    manifest_path, archive_path = Path(f"{name}-latest.json"), Path(f"{name}.zip")

    download_file(manifest_url, manifest_path)
    with manifest_path.open(encoding="utf-8") as fh:
        manifest = json.load(fh)

    expected_paths = list(map(Path, manifest.get("expected_paths", [])))
    missing = [str(p) for p in expected_paths if not p.exists()]

    if missing:
        download_file(manifest["archive_url"], archive_path)
        extract_zip_safely(archive_path, manifest.get("extract_to", "."))
        # Re-check after extraction: the archive must supply every expected path.
        missing = [str(p) for p in expected_paths if not p.exists()]

    if missing:
        raise FileNotFoundError(f"Missing expected paths: {missing}")

    print(f"{manifest['name']} ready.")

Imports

Libraries for retina vessel segmentation with a U-Net style model.

import copy
import random
import time
import re
from pathlib import Path
 
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision.transforms import functional as TF
 
print(f"PyTorch version: {torch.__version__}")

Hyperparameters

Configure dataset paths, crop-based training, and transfer-vs-scratch mode.

hyperparameters = {
    'dataset_roots': ['./datasets/DRIVE',],
    'image_size': 224,
    'train_crop_size': 224,
    'batch_size': 8,
    'num_epochs': 45,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'num_workers': 2,
    'persistent_workers': True,
    'train_valid_split': 0.8,
    'random_seed': 42,
 
    # GPU selection (default: use GPU 1)
    'gpu_id': 1,
 
    # 'transfer' -> ResNet50 encoder initialized from ImageNet
    # 'scratch'  -> all layers randomly initialized
    'training_mode': 'scratch',
    'encoder_name': 'resnet50',
 
    # random crops per image for one training epoch
    'train_samples_per_image': 20,
 
    # sliding window parameters for full-image test segmentation
    'inference_tile_size': 224,
    'inference_stride': 112,
    'inference_blend': 'hann',  # options: uniform, hann
}

# Seed the Python, NumPy and torch (CPU + all CUDA devices) RNGs.
seed = hyperparameters['random_seed']
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

requested_gpu_id = hyperparameters.get('gpu_id', 1)
if torch.cuda.is_available():
    n_gpus = torch.cuda.device_count()
    # Accept only a valid non-negative index. A negative id would otherwise
    # reach torch.device('cuda:-1'), which raises, so it falls back to cuda:0
    # exactly like an out-of-range id.
    if 0 <= requested_gpu_id < n_gpus:
        device = torch.device(f'cuda:{requested_gpu_id}')
    else:
        print(f"Requested gpu_id={requested_gpu_id} is not a valid index for {n_gpus} available GPU(s). Falling back to cuda:0.")
        device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

print('Using device:', device)
print('Requested gpu_id:', requested_gpu_id)
print('Random seed:', seed)
 
 
def resolve_dataset_root(roots):
    """Return the first candidate directory that contains both ``train/`` and ``test/``.

    Args:
        roots: Iterable of candidate directory paths (str or Path).

    Returns:
        The matching candidate as a ``Path``.

    Raises:
        FileNotFoundError: If no candidate has both sub-folders.
    """
    def _has_splits(candidate):
        return (candidate / 'train').exists() and (candidate / 'test').exists()

    found = next((c for c in map(Path, roots) if _has_splits(c)), None)
    if found is None:
        raise FileNotFoundError(
            f"Could not find DRIVE root in any of: {roots}. Expected train/ and test/ subfolders."
        )
    return found
    
    
# Locate the dataset once and derive the image folders for each split.
dataset_root = resolve_dataset_root(hyperparameters['dataset_roots'])
train_root, test_root = (dataset_root / split_name for split_name in ('train', 'test'))
print('Using dataset root:', dataset_root.resolve())

U-Net Model And Dice Loss

ResNet50-encoder U-Net, configurable in transfer learning or from-scratch mode.

def dice_coefficient_from_logits(logits, targets, threshold=0.5, eps=1e-6):
    """Hard (thresholded) Dice score averaged over the batch.

    Args:
        logits: Raw model outputs, reduced over dims 1-3 (e.g. ``[B, 1, H, W]``).
        targets: Binary ground truth with the same shape as *logits*.
        threshold: Probability cut-off applied after the sigmoid.
        eps: Smoothing term added to numerator and denominator.

    Returns:
        0-dim tensor with the mean per-sample Dice.
    """
    reduce_dims = (1, 2, 3)
    binary = torch.sigmoid(logits).gt(threshold).float()
    truth = targets.float()

    overlap = (binary * truth).sum(dim=reduce_dims)
    total = binary.sum(dim=reduce_dims) + truth.sum(dim=reduce_dims)
    return ((2.0 * overlap + eps) / (total + eps)).mean()
 
 
class DiceLoss(nn.Module):
    """Soft Dice loss in the V-Net form (squared terms in the denominator).

    ``loss = 1 - mean_b[(2 * sum(p * t) + eps) / (sum(p^2 + t^2) + eps)]`` where
    ``p = sigmoid(logits)``. Sums run over dims 1-3.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        axes = (1, 2, 3)
        p = logits.sigmoid()
        t = targets.float()

        numerator = 2.0 * (p * t).sum(dim=axes) + self.eps
        denominator = (p * p + t * t).sum(dim=axes) + self.eps
        return 1.0 - (numerator / denominator).mean()
        
        
class ConvBlock(nn.Module):
    """Two stacked ``3x3 conv (no bias) -> BatchNorm -> ReLU`` stages.

    Padding 1 keeps the spatial size; the channel count goes ``in_ch -> out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second keeps out_ch.
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
 
 
class UpBlock(nn.Module):
    """U-Net decoder stage: 2x transposed-conv upsampling, skip concatenation, ConvBlock.

    Args:
        in_ch: Channels of the incoming decoder feature map.
        skip_ch: Channels of the encoder skip connection.
        out_ch: Channels produced by the upsampling and by the stage output.
    """

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = ConvBlock(out_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        upsampled = self.up(x)
        skip_hw = skip.shape[-2:]
        # Odd input sizes can leave the upsampled map one pixel off the skip
        # tensor; resize so the channel-wise concatenation lines up.
        if upsampled.shape[-2:] != skip_hw:
            upsampled = F.interpolate(upsampled, size=skip_hw, mode='bilinear', align_corners=False)
        return self.conv(torch.cat((upsampled, skip), dim=1))
        
        
class ResNet50UNet(nn.Module):
    """U-Net whose encoder is torchvision's ResNet50; outputs 1-channel logits at input size.

    Encoder features reused as skips: the stem (64 ch) and layer1-layer4
    (256 / 512 / 1024 / 2048 ch, as implied by the decoder channel arguments).
    The ResNet avgpool/fc head is not registered on this module.

    Args:
        pretrained_encoder: Load ``ResNet50_Weights.DEFAULT`` ImageNet weights
            when True; otherwise the encoder is randomly initialized.

    NOTE(review): no input normalization is applied here. ImageNet-pretrained
    weights usually expect mean/std-normalized inputs — confirm this is handled
    upstream when using transfer mode.
    """

    def __init__(self, pretrained_encoder=True):
        super().__init__()
        backbone = models.resnet50(
            weights=models.ResNet50_Weights.DEFAULT if pretrained_encoder else None
        )

        # Encoder: borrow the backbone's modules directly (same attribute names
        # and registration order, so state_dict keys are stable).
        self.stem = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.pool = backbone.maxpool
        self.enc1, self.enc2, self.enc3, self.enc4 = (
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
        )

        self.center = ConvBlock(2048, 1024)

        # Decoder stages, deepest first: UpBlock(in_ch, skip_ch, out_ch).
        self.up4 = UpBlock(1024, 1024, 512)
        self.up3 = UpBlock(512, 512, 256)
        self.up2 = UpBlock(256, 256, 128)
        self.up1 = UpBlock(128, 64, 64)

        self.head = nn.Sequential(ConvBlock(64, 32), nn.Conv2d(32, 1, kernel_size=1))

    def forward(self, x):
        out_hw = x.shape[-2:]

        s0 = self.stem(x)               # 64 ch
        s1 = self.enc1(self.pool(s0))   # 256 ch
        s2 = self.enc2(s1)              # 512 ch
        s3 = self.enc3(s2)              # 1024 ch
        y = self.center(self.enc4(s3))  # 2048 -> 1024 ch

        for stage, skip in ((self.up4, s3), (self.up3, s2), (self.up2, s1), (self.up1, s0)):
            y = stage(y, skip)

        # The decoder stops at stem resolution; resize logits back to the input size.
        return F.interpolate(self.head(y), size=out_hw, mode='bilinear', align_corners=False)
 
 
def build_model(hparams):
    """Instantiate the segmentation model described by *hparams*.

    Reads ``encoder_name`` (only ``'resnet50'`` is supported) and
    ``training_mode`` (``'transfer'`` -> ImageNet encoder weights,
    ``'scratch'`` -> random init); both are matched case-insensitively.

    Raises:
        ValueError: On an unsupported encoder or training mode.
    """
    if hparams['encoder_name'].lower() != 'resnet50':
        raise ValueError("Only encoder_name='resnet50' is implemented in this notebook.")

    pretrained_by_mode = {'transfer': True, 'scratch': False}
    mode = hparams['training_mode'].lower()
    if mode not in pretrained_by_mode:
        raise ValueError("training_mode must be either 'transfer' or 'scratch'.")

    return ResNet50UNet(pretrained_encoder=pretrained_by_mode[mode])
 
 
def count_trainable_parameters(model):
    """Return ``(total, trainable)`` element counts over ``model.parameters()``."""
    total = trainable = 0
    for param in model.parameters():
        n_elements = param.numel()
        total += n_elements
        if param.requires_grad:
            trainable += n_elements
    return total, trainable

Optimization Setup

Dice loss, AdamW optimizer, and cosine learning-rate schedule.

# Training/evaluation objective. Module.to() moves in place and returns self;
# DiceLoss has no parameters, so this is only for device consistency.
criterion = DiceLoss()
criterion.to(device)
 
  
def build_optimizer_and_scheduler(model, hparams):
    """Create AdamW over all model parameters plus a cosine LR schedule.

    Reads ``learning_rate``, optional ``weight_decay`` (default 0.0) and
    ``num_epochs`` (used as ``T_max``, i.e. the schedule is meant to be stepped
    once per epoch).

    Returns:
        ``(optimizer, scheduler)``.
    """
    base_lr = hparams['learning_rate']
    decay = hparams.get('weight_decay', 0.0)
    optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=decay)

    total_epochs = hparams['num_epochs']
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epochs)
    return optimizer, scheduler

Model Summary and Trainable Parameters

Inspect the model and verify how many parameters are trainable under the selected training mode (`transfer` or `scratch`).

# Report parameter counts if a model has already been built in this session.
_model = globals().get('model')
if _model is not None:
    total_params, trainable_params = count_trainable_parameters(_model)
    print(
        "Model: ResNet50 U-Net",
        f"Training mode: {hyperparameters['training_mode']}",
        f"Total parameters: {total_params:,}",
        f"Trainable parameters: {trainable_params:,}",
        sep="\n",
    )
else:
    print('Model not found yet. Run the dataset/dataloader cell first, then re-run this cell.')

Build DRIVE Splits

Load the DRIVE train/test image–mask pairs, then split the training images into training and validation subsets (80%/20% by default, set by `train_valid_split`).

def extract_numeric_id(path: Path):
    """Return the first run of digits in the file stem as an int, or None if there is none.

    ``21_training.png`` -> 21, ``img007_x12.tif`` -> 7.
    """
    match = re.search(r'(\d+)', path.stem)
    if match is None:
        return None
    return int(match.group(1))
 
 
def build_drive_pairs(image_dir: Path, gt_dir: Path, image_suffixes=('.png',)):
    """Pair every image in *image_dir* with the mask in *gt_dir* sharing its numeric id.

    Ids come from ``extract_numeric_id`` (first run of digits in the stem), so
    e.g. ``21_training.png`` pairs with ``21_manual1.gif``.

    Args:
        image_dir: Directory containing the input images.
        gt_dir: Directory with ground-truth masks
            (``.gif``/``.png``/``.tif``/``.tiff``/``.jpg``/``.jpeg``, case-insensitive).
        image_suffixes: Image extensions (with the dot) to collect via
            ``image_dir.glob('*<suffix>')``. A single string is also accepted.
            The default ``('.png',)`` reproduces the original behaviour.

    Returns:
        List of ``(image_path, mask_path)`` tuples sorted by numeric id.

    Raises:
        RuntimeError: If no image/mask pair could be formed.
    """
    if isinstance(image_suffixes, str):
        image_suffixes = (image_suffixes,)
    image_paths = sorted({p for suffix in image_suffixes for p in image_dir.glob(f'*{suffix}')})

    mask_suffixes = {'.gif', '.png', '.tif', '.tiff', '.jpg', '.jpeg'}
    gt_paths = sorted(p for p in gt_dir.glob('*') if p.suffix.lower() in mask_suffixes)

    gt_by_id = {}
    duplicate_ids = set()
    for g in gt_paths:
        gid = extract_numeric_id(g)
        if gid is None:
            continue
        if gid in gt_by_id:
            duplicate_ids.add(gid)
        # As before, the last mask in sorted order wins for a repeated id;
        # the warning below makes that silent overwrite visible.
        gt_by_id[gid] = g

    if duplicate_ids:
        print(
            'Warning: several ground-truth files share ids (using the last one in sorted order):',
            sorted(duplicate_ids),
        )

    pairs = []
    missing = []
    for img in image_paths:
        iid = extract_numeric_id(img)
        if iid is None or iid not in gt_by_id:
            missing.append(img.name)
            continue
        pairs.append((img, gt_by_id[iid]))

    if missing:
        print('Warning: images without matching ground truth:', missing)

    if not pairs:
        raise RuntimeError(
            f'No image/mask pairs found. Checked images in {image_dir} and masks in {gt_dir}.'
        )

    return sorted(pairs, key=lambda x: extract_numeric_id(x[0]))
 
 
class DriveTrainRandomCropDataset(Dataset):
    """Random square crops from DRIVE training pairs, with optional flip augmentation.

    The dataset length is ``len(pairs) * samples_per_image``; index ``i`` maps to
    pair ``i % len(pairs)``, so each image contributes ``samples_per_image``
    independent random crops per epoch. Crops and flips use the ``random``
    module.

    Items are ``(image, mask)``: image ``[3, crop, crop]`` in ``[0, 1]`` and a
    binary float mask ``[1, crop, crop]``. Images smaller than the crop are
    zero-padded on the bottom/right first.
    """

    def __init__(self, pairs, crop_size=224, samples_per_image=6, augment=True):
        self.pairs = list(pairs)
        self.crop_size = crop_size
        self.samples_per_image = samples_per_image
        self.augment = augment

    def __len__(self):
        return self.samples_per_image * len(self.pairs)

    def __getitem__(self, idx):
        img_path, mask_path = self.pairs[idx % len(self.pairs)]

        image = TF.to_tensor(Image.open(img_path).convert('RGB'))
        mask = (TF.to_tensor(Image.open(mask_path).convert('L')) > 0.5).float()

        crop = self.crop_size
        pad_h = max(0, crop - image.shape[-2])
        pad_w = max(0, crop - image.shape[-1])
        if pad_h or pad_w:
            padding = (0, pad_w, 0, pad_h)
            image = F.pad(image, padding, mode='constant', value=0.0)
            mask = F.pad(mask, padding, mode='constant', value=0.0)

        h, w = image.shape[-2:]
        top = random.randint(0, h - crop)
        left = random.randint(0, w - crop)
        window = (slice(None), slice(top, top + crop), slice(left, left + crop))
        image, mask = image[window], mask[window]

        if self.augment:
            # One coin toss per flip type, horizontal first, applied to both tensors.
            for flip in (TF.hflip, TF.vflip):
                if random.random() < 0.5:
                    image, mask = flip(image), flip(mask)

        return image, mask
 
 
class DriveFullImageDataset(Dataset):
    """Full-resolution DRIVE samples for validation/testing.

    Items are ``(image, mask, sample_id)``: image ``[3, H, W]`` in ``[0, 1]``,
    a binary float mask ``[1, H, W]`` (grayscale thresholded at 0.5), and the
    image file stem.
    """

    def __init__(self, pairs):
        self.pairs = list(pairs)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img_path, mask_path = self.pairs[idx]
        rgb = Image.open(img_path).convert('RGB')
        gray = Image.open(mask_path).convert('L')

        image = TF.to_tensor(rgb)
        mask = TF.to_tensor(gray).gt(0.5).float()
        return image, mask, img_path.stem
 
 
# Ground-truth folders sit next to the image folders.
train_gt_root = dataset_root / 'train_groundtruth'
test_gt_root = dataset_root / 'test_groundtruth'

all_train_pairs = build_drive_pairs(train_root, train_gt_root)
test_pairs = build_drive_pairs(test_root, test_gt_root)

# Deterministic shuffle (own Random instance, seeded) before the train/valid cut.
n_train_all = len(all_train_pairs)
indices = list(range(n_train_all))
random.Random(seed).shuffle(indices)

split = int(n_train_all * hyperparameters['train_valid_split'])
train_idx, valid_idx = indices[:split], indices[split:]

train_pairs = [all_train_pairs[i] for i in train_idx]
valid_pairs = [all_train_pairs[i] for i in valid_idx]

print(
    f'Train images total: {n_train_all}',
    f'Train split: {len(train_pairs)}',
    f'Valid split: {len(valid_pairs)}',
    f'Test images: {len(test_pairs)}',
    sep='\n',
)

Pairing Sanity Check

Visualize image/mask pairs to verify correct correspondence.

def show_image_mask_pairs(pairs, max_items=4):
    """Plot the first *max_items* ``(image_path, mask_path)`` pairs side by side.

    Each row shows the RGB image (left) and its grayscale mask (right), titled
    with the file names, to confirm that pairing by numeric id is correct.
    """
    rows = min(max_items, len(pairs))
    plt.figure(figsize=(10, 3 * rows))

    for row, (img_path, mask_path) in zip(range(rows), pairs):
        panels = (
            (np.array(Image.open(img_path).convert('RGB')), None, f'Image: {img_path.name}'),
            (np.array(Image.open(mask_path).convert('L')), 'gray', f'Mask: {mask_path.name}'),
        )
        for col, (data, cmap, title) in enumerate(panels, start=1):
            plt.subplot(rows, 2, 2 * row + col)
            plt.imshow(data, cmap=cmap)
            plt.title(title)
            plt.axis('off')

    plt.tight_layout()
    plt.show()
 
 
show_image_mask_pairs(pairs=train_pairs, max_items=4)

Create Segmentation Datasets

Random-crop training dataset and full-image validation/test datasets.

# Training uses random crops; validation and test run on whole images.
train_dataset = DriveTrainRandomCropDataset(
    train_pairs,
    crop_size=hyperparameters['train_crop_size'],
    samples_per_image=hyperparameters['train_samples_per_image'],
    augment=True,
)
valid_dataset, test_dataset = DriveFullImageDataset(valid_pairs), DriveFullImageDataset(test_pairs)

print(
    'Datasets ready:',
    f'  train samples per epoch: {len(train_dataset)}',
    f'  valid full images: {len(valid_dataset)}',
    f'  test full images: {len(test_dataset)}',
    sep='\n',
)

DataLoaders And Model

Build loaders and instantiate U-Net according to selected mode.

loader_num_workers = hyperparameters.get('num_workers', 2)
# Persistent workers are only meaningful when worker processes exist.
use_persistent_workers = hyperparameters.get('persistent_workers', True) and loader_num_workers > 0
pin_memory = torch.cuda.is_available()

shared_loader_kwargs = dict(
    num_workers=loader_num_workers,
    pin_memory=pin_memory,
    persistent_workers=use_persistent_workers,
)

train_loader = DataLoader(
    train_dataset, batch_size=hyperparameters['batch_size'], shuffle=True, **shared_loader_kwargs
)
# Full-image validation/testing are done one image at a time.
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, **shared_loader_kwargs)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, **shared_loader_kwargs)

model = build_model(hyperparameters).to(device)
optimizer, scheduler = build_optimizer_and_scheduler(model, hyperparameters)

total_params, trainable_params = count_trainable_parameters(model)
print(
    "Model: ResNet50 U-Net",
    f"Mode: {hyperparameters['training_mode']}",
    f"Trainable parameters: {trainable_params:,} / {total_params:,}",
    sep="\n",
)

Crop Visualization

Inspect random training crops (224x224 with the default `train_crop_size`) and their vessel masks.

def show_train_crops(images, masks, max_items=4):
    """Plot up to *max_items* training crops next to their masks.

    Args:
        images: Image batch ``[B, 3, H, W]`` with values suitable for ``imshow``.
        masks: Mask batch ``[B, 1, H, W]``; channel 0 is shown.
        max_items: Maximum number of rows to draw.

    The title reports the real crop size taken from the batch. It was
    previously hard-coded to 224x224 even though ``train_crop_size`` is
    configurable.
    """
    n = min(max_items, images.shape[0])
    crop_h, crop_w = images.shape[-2:]
    plt.figure(figsize=(8, 3 * n))

    for i in range(n):
        img = images[i].detach().cpu().permute(1, 2, 0).numpy()
        msk = masks[i, 0].detach().cpu().numpy()

        plt.subplot(n, 2, 2 * i + 1)
        plt.imshow(img)
        plt.title(f'Random train crop ({crop_h}x{crop_w})')
        plt.axis('off')

        plt.subplot(n, 2, 2 * i + 2)
        plt.imshow(msk, cmap='gray')
        plt.title('Crop mask')
        plt.axis('off')

    plt.tight_layout()
    plt.show()
 
  
# Pull one batch from the training loader to sanity-check shapes and crops.
images, masks = next(iter(train_loader))
print(f'Batch image tensor: {images.shape}', f'Batch mask tensor: {masks.shape}', sep='\n')
show_train_crops(images, masks, max_items=4)

Training Step

One training epoch on random crop samples.

def train_one_epoch(model, train_loader, criterion, optimizer):
    """Run one optimization pass over *train_loader*.

    Batches are moved to the module-level ``device``.

    Args:
        model: Network mapping images to per-pixel logits.
        train_loader: Iterable of ``(images, masks)`` batches.
        criterion: Loss called as ``criterion(logits, masks)``.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        ``(avg_loss, avg_dice)``: sample-weighted epoch averages of the loss and
        of the hard Dice from ``dice_coefficient_from_logits``.
    """
    model.train()
    running_loss = 0.0
    running_dice = 0.0
    n_samples = 0

    for images, masks in train_loader:
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(images)
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        # The metric needs no gradients. Detaching stops the sigmoid/threshold
        # ops from being recorded in a fresh autograd graph on every step.
        batch_dice = dice_coefficient_from_logits(logits.detach(), masks)
        running_dice += batch_dice.item() * batch_size
        n_samples += batch_size

    avg_loss = running_loss / max(1, n_samples)
    avg_dice = running_dice / max(1, n_samples)
    return avg_loss, avg_dice

Full-Image Evaluation

Sliding-window inference to segment full retina images and compute Dice.

def _sliding_positions(length, tile_size, stride):
    if length <= tile_size:
        return [0]
 
    positions = list(range(0, length - tile_size + 1, stride))
    if positions[-1] != length - tile_size:
        positions.append(length - tile_size)
    return positions
 
 
def _build_patch_weight(tile_size, mode='uniform', eps=1e-6, device='cpu'):
    if mode == 'uniform':
        w = torch.ones((tile_size, tile_size), device=device)
    elif mode == 'hann':
        h = torch.hann_window(tile_size, periodic=False, device=device)
        w = torch.outer(h, h)
        w = w + eps
    else:
        raise ValueError("blend mode must be one of: uniform, hann")
    return w.view(1, 1, tile_size, tile_size)
 
 
def sliding_window_logits(model, image, tile_size=224, stride=112, blend='uniform', return_count_map=False):
    """Segment an arbitrarily sized image by blending overlapping tile predictions.

    Tiles of ``tile_size`` are taken every ``stride`` pixels, with the last tile
    flush with each edge. Each tile's logits are multiplied by a blending weight
    (see ``_build_patch_weight``), summed, and divided by the summed weights.
    Images smaller than one tile are zero-padded bottom/right and cropped back
    afterwards.

    Args:
        model: Callable mapping a ``[B, 3, tile, tile]`` patch to
            ``[B, C, tile, tile]`` logits. No ``no_grad``/``eval`` is applied
            here; callers are expected to handle that.
        image: ``[B, 3, H, W]`` tensor (the notebook always passes B == 1).
        tile_size, stride: Window geometry in pixels.
        blend: ``'uniform'`` or ``'hann'``.
        return_count_map: Also return the accumulated per-pixel weight.

    Returns:
        Merged logits ``[B, C, H, W]``; with *return_count_map*, also the
        weight map ``[1, 1, H, W]``.
    """
    _, _, h, w = image.shape

    if h < tile_size or w < tile_size:
        pad_h = max(0, tile_size - h)
        pad_w = max(0, tile_size - w)
        image = F.pad(image, (0, pad_w, 0, pad_h), mode='constant', value=0.0)

    _, _, hp, wp = image.shape
    y_positions = _sliding_positions(hp, tile_size, stride)
    x_positions = _sliding_positions(wp, tile_size, stride)

    # The logit accumulator is allocated from the first patch's output. Its
    # batch and channel dims then follow the model instead of being fixed to
    # (1, 1), which previously broke for batches > 1. The weight map is shared
    # by every sample and channel, so it stays (1, 1, H, W).
    logits_sum = None
    count_map = torch.zeros((1, 1, hp, wp), device=image.device)
    patch_weight = _build_patch_weight(tile_size, mode=blend, device=image.device)

    for y in y_positions:
        for x in x_positions:
            patch = image[:, :, y:y + tile_size, x:x + tile_size]
            patch_logits = model(patch)
            if logits_sum is None:
                logits_sum = torch.zeros((*patch_logits.shape[:2], hp, wp), device=image.device)
            logits_sum[:, :, y:y + tile_size, x:x + tile_size] += patch_logits * patch_weight
            count_map[:, :, y:y + tile_size, x:x + tile_size] += patch_weight

    merged_logits = logits_sum / count_map.clamp_min(1e-6)
    merged_logits = merged_logits[:, :, :h, :w]

    if return_count_map:
        return merged_logits, count_map[:, :, :h, :w]
    return merged_logits
 
 
def evaluate_full_images(model, data_loader, criterion, tile_size=224, stride=112, threshold=0.5, blend='uniform'):
    """Evaluate *model* on whole images via sliding-window inference.

    Batches ``(images, masks, ids)`` are moved to the module-level ``device``
    and run under ``torch.no_grad()``.

    Returns:
        ``(avg_loss, mean_image_dice, overall_dice)``:

        - ``avg_loss``: mean of ``criterion`` per batch.
        - ``mean_image_dice``: mean of per-batch hard Dice.
        - ``overall_dice``: Dice from intersections and unions pooled over the
          whole loader.
    """
    model.eval()

    loss_sum = dice_sum = 0.0
    n_images = 0
    inter_total = union_total = 0.0

    with torch.no_grad():
        for images, masks, _ in data_loader:
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)

            logits = sliding_window_logits(model, images, tile_size=tile_size, stride=stride, blend=blend)
            loss = criterion(logits, masks)
            preds = torch.sigmoid(logits).gt(threshold).float()

            inter = (preds * masks).sum().item()
            union = preds.sum().item() + masks.sum().item()

            inter_total += inter
            union_total += union
            loss_sum += loss.item()
            dice_sum += (2.0 * inter + 1e-6) / (union + 1e-6)
            n_images += 1

    denom = max(1, n_images)
    overall_dice = (2.0 * inter_total + 1e-6) / (union_total + 1e-6)
    return loss_sum / denom, dice_sum / denom, overall_dice
 
 
def diagnose_sliding_window_artifacts(model, image, tile_size=224, stride=112):
    """Compare uniform vs Hann tile blending on one image.

    Returns a dict of NumPy arrays (first sample, first channel):
    ``uniform_probs``, ``hann_probs``, ``uniform_count_map``,
    ``hann_count_map`` and ``delta_abs = |hann_probs - uniform_probs|``.
    """
    model.eval()
    per_blend = {}
    with torch.no_grad():
        for mode in ('uniform', 'hann'):
            logits, weights = sliding_window_logits(
                model,
                image,
                tile_size=tile_size,
                stride=stride,
                blend=mode,
                return_count_map=True,
            )
            per_blend[mode] = (
                torch.sigmoid(logits)[0, 0].cpu().numpy(),
                weights[0, 0].cpu().numpy(),
            )

    uniform_probs, uniform_weights = per_blend['uniform']
    hann_probs, hann_weights = per_blend['hann']
    return {
        'uniform_probs': uniform_probs,
        'hann_probs': hann_probs,
        'uniform_count_map': uniform_weights,
        'hann_count_map': hann_weights,
        'delta_abs': np.abs(hann_probs - uniform_probs),
    }

Train And Test

Train on random crops, validate on full images, and report test Dice metrics.

# Per-epoch history for the curves plotted after testing.
train_losses = []
valid_losses = []
train_dices = []
valid_mean_dices = []
valid_overall_dices = []

# Keep a copy of the weights with the best validation overall Dice.
best_val_dice = -1.0
best_epoch = 0
best_model_state = copy.deepcopy(model.state_dict())

for epoch in range(hyperparameters['num_epochs']):
    start = time.time()

    train_loss, train_dice = train_one_epoch(model, train_loader, criterion, optimizer)

    val_loss, val_mean_dice, val_overall_dice = evaluate_full_images(
        model,
        valid_loader,
        criterion,
        tile_size=hyperparameters['inference_tile_size'],
        stride=hyperparameters['inference_stride'],
        blend=hyperparameters.get('inference_blend', 'uniform'),
    )

    # Capture the LR actually used for this epoch BEFORE stepping the scheduler.
    # Reading it after scheduler.step() would log the next epoch's LR.
    lr = optimizer.param_groups[0]['lr']
    scheduler.step()

    train_losses.append(train_loss)
    valid_losses.append(val_loss)
    train_dices.append(train_dice)
    valid_mean_dices.append(val_mean_dice)
    valid_overall_dices.append(val_overall_dice)

    elapsed = time.time() - start
    print(
        f"Epoch {epoch+1:02d}/{hyperparameters['num_epochs']} | "
        f"{elapsed:.1f}s | lr={lr:.2e} | "
        f"train loss={train_loss:.4f} dice={train_dice:.4f} | "
        f"val loss={val_loss:.4f} mean-dice={val_mean_dice:.4f} overall-dice={val_overall_dice:.4f}"
    )

    if val_overall_dice > best_val_dice:
        best_val_dice = val_overall_dice
        best_epoch = epoch + 1
        best_model_state = copy.deepcopy(model.state_dict())
 
  
# Restore the best-on-validation weights before touching the test set.
model.load_state_dict(best_model_state)

inference_kwargs = dict(
    tile_size=hyperparameters['inference_tile_size'],
    stride=hyperparameters['inference_stride'],
    blend=hyperparameters.get('inference_blend', 'uniform'),
)

test_loss, test_mean_dice, test_overall_dice = evaluate_full_images(
    model, test_loader, criterion, **inference_kwargs
)

# Pixel-level test accuracy computed only inside FOV (non-black image pixels).
# NOTE(review): this FOV is a heuristic (any channel > 0). DRIVE also ships
# official FOV masks; confirm the background is exactly 0 in these images.
model.eval()
correct_pixels = 0
total_pixels = 0
with torch.no_grad():
    for images, masks, _ in test_loader:
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        logits = sliding_window_logits(model, images, **inference_kwargs)
        preds = (torch.sigmoid(logits) > 0.5).float()
        targets = (masks > 0.5).float()

        fov_mask = (images.sum(dim=1, keepdim=True) > 0).float()
        correct_pixels += ((preds == targets).float() * fov_mask).sum().item()
        total_pixels += fov_mask.sum().item()

test_accuracy = correct_pixels / max(1.0, total_pixels)

print(
    f'Best validation overall Dice: {best_val_dice:.4f} at epoch {best_epoch}',
    f'Test loss: {test_loss:.4f}',
    f'Test mean-image Dice: {test_mean_dice:.4f}',
    f'Test overall Dice: {test_overall_dice:.4f}',
    f'Test pixel accuracy (FOV only): {test_accuracy:.4f}',
    sep='\n',
)
 
# Loss (left) and Dice (right) per epoch.
epochs = np.arange(1, len(train_losses) + 1)
plt.figure(figsize=(12, 4))

curve_panels = (
    ('Loss per epoch', 'Loss', (
        (train_losses, 'Train Dice loss'),
        (valid_losses, 'Valid Dice loss'),
    )),
    ('Dice per epoch', 'Dice', (
        (train_dices, 'Train Dice'),
        (valid_mean_dices, 'Valid mean-image Dice'),
        (valid_overall_dices, 'Valid overall Dice'),
    )),
)

for panel_idx, (panel_title, panel_ylabel, series) in enumerate(curve_panels, start=1):
    plt.subplot(1, 2, panel_idx)
    for values, series_label in series:
        plt.plot(epochs, values, marker='o', label=series_label)
    plt.title(panel_title)
    plt.xlabel('Epoch')
    plt.ylabel(panel_ylabel)
    plt.legend()

plt.tight_layout()
plt.show()

Qualitative Results

Visual comparison of full-image predictions against ground truth masks.

def show_test_predictions(model, data_loader, max_items=4, threshold=0.5):
    """Plot image, ground truth, uniform-merge and configured-merge predictions.

    One row per loader batch (first sample), up to *max_items* rows. Each
    prediction panel is titled with its hard Dice against the mask. Tiling
    parameters come from the module-level ``hyperparameters``.
    """
    model.eval()

    def _hard_dice(pred, gt):
        inter = (pred * gt).sum().item()
        union = pred.sum().item() + gt.sum().item()
        return (2.0 * inter + 1e-6) / (union + 1e-6)

    tile = hyperparameters['inference_tile_size']
    step = hyperparameters['inference_stride']
    blend_name = hyperparameters.get('inference_blend', 'uniform')

    plt.figure(figsize=(16, 4 * max_items))
    shown = 0

    with torch.no_grad():
        for images, masks, sample_ids in data_loader:
            images, masks = images.to(device), masks.to(device)

            preds = {}
            for key, mode in (('uniform', 'uniform'), ('blend', blend_name)):
                logits = sliding_window_logits(model, images, tile_size=tile, stride=step, blend=mode)
                preds[key] = (torch.sigmoid(logits) > threshold).float()

            dice_u = _hard_dice(preds['uniform'], masks)
            dice_b = _hard_dice(preds['blend'], masks)

            panels = (
                (images[0].cpu().permute(1, 2, 0).numpy(), None, f'{sample_ids[0]} - image'),
                (masks[0, 0].cpu().numpy(), 'gray', 'ground truth'),
                (preds['uniform'][0, 0].cpu().numpy(), 'gray', f'uniform merge (Dice={dice_u:.3f})'),
                (preds['blend'][0, 0].cpu().numpy(), 'gray', f"{blend_name} merge (Dice={dice_b:.3f})"),
            )
            for col, (data, cmap, title) in enumerate(panels, start=1):
                plt.subplot(max_items, 4, 4 * shown + col)
                plt.imshow(data, cmap=cmap)
                plt.title(title)
                plt.axis('off')

            shown += 1
            if shown >= max_items:
                break

    plt.tight_layout()
    plt.show()
    
    
show_test_predictions(model=model, data_loader=test_loader, max_items=4)