from pathlib import Path
from urllib.request import Request, urlopen
import json
import shutil
import zipfile
 
RESOURCE_MANIFESTS = [
    "https://assets.deeplearningnotes.com/code-support-resources/datasets/bccd/latest.json",
]
 
def download_file(url, path, timeout=60):
    """Download ``url`` to ``path``.

    The payload is streamed into ``<path>.part`` and only renamed onto ``path``
    once the transfer finished, so an interrupted download never leaves a
    truncated file at the final location.

    Args:
        url: Source URL (any scheme ``urllib.request.urlopen`` supports).
        path: Destination file path (str or Path).
        timeout: Socket timeout in seconds passed to ``urlopen``; without one a
            stalled server would block forever.
    """
    path = Path(path)
    tmp_path = path.with_name(path.name + ".part")
    request = Request(url, headers={"User-Agent": "Mozilla/5.0"})
    try:
        with urlopen(request, timeout=timeout) as response, open(tmp_path, "wb") as file:
            shutil.copyfileobj(response, file)
        tmp_path.replace(path)
    finally:
        # Only present if the transfer failed part-way; clean it up.
        if tmp_path.exists():
            tmp_path.unlink()
 
def extract_zip_safely(zip_path, target_dir="."):
    """Extract ``zip_path`` into ``target_dir``, refusing archives that escape it.

    Every member path is resolved first. If any member would land outside
    ``target_dir`` (``..`` components, absolute paths, symlinked parents), nothing
    is extracted.

    Raises:
        RuntimeError: If a member resolves outside ``target_dir``.
    """
    target_dir = Path(target_dir).resolve()

    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        for member in zip_ref.infolist():
            target_path = (target_dir / member.filename).resolve()
            # Compare path components, not string prefixes: "/data/out_evil"
            # starts with "/data/out" as a string but is NOT inside it.
            if target_path != target_dir and target_dir not in target_path.parents:
                raise RuntimeError(f"Unsafe zip path: {member.filename}")

        zip_ref.extractall(target_dir)
 
for manifest_url in RESOURCE_MANIFESTS:
    # Local slug is the URL segment before the file name: ".../bccd/latest.json" -> "bccd".
    name = manifest_url.rstrip("/").split("/")[-2]
    manifest_path, archive_path = Path(f"{name}-latest.json"), Path(f"{name}.zip")

    # The manifest is refreshed on every run; it names the archive and the paths it must produce.
    download_file(manifest_url, manifest_path)
    with open(manifest_path, encoding="utf-8") as manifest_file:
        manifest = json.load(manifest_file)

    expected_paths = list(map(Path, manifest.get("expected_paths", [])))

    # Fetch and unpack the archive only when something expected is absent.
    if any(not expected.exists() for expected in expected_paths):
        download_file(manifest["archive_url"], archive_path)
        extract_zip_safely(archive_path, manifest.get("extract_to", "."))

    missing = [str(expected) for expected in expected_paths if not expected.exists()]
    if missing:
        raise FileNotFoundError(f"Missing expected paths: {missing}")

    print(f"{manifest['name']} ready.")

Imports

Libraries for blood cell object detection and classification with RetinaNet on BCCD.

import copy
import os
import random
import subprocess
import sys
import time
import xml.etree.ElementTree as ET
from collections import OrderedDict, defaultdict
from pathlib import Path
 
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import functional as TF
from torchvision.ops import box_iou
from torchvision.models import (
    ResNet18_Weights,
    ResNet34_Weights,
    ResNet50_Weights,
    ResNet101_Weights,
    ResNet152_Weights,
    resnet18,
    resnet34,
    resnet50,
    resnet101,
    resnet152,
)
try:
    from torchvision.models import (
        Swin_B_Weights,
        Swin_S_Weights,
        Swin_T_Weights,
        swin_b,
        swin_s,
        swin_t,
    )
except ImportError:
    Swin_B_Weights = None
    Swin_S_Weights = None
    Swin_T_Weights = None
    swin_b = None
    swin_s = None
    swin_t = None
from torchvision.models.detection import (
    RetinaNet,
    RetinaNet_ResNet50_FPN_Weights,
    retinanet_resnet50_fpn,
)
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelP6P7
 
# Optional dependency: only required for timm-built backbones (backbone_name='swin_l').
try:
    import timm
except ImportError:
    timm = None

# wandb is required; if missing it is pip-installed into the current interpreter.
# NOTE(review): installing packages at import time mutates the active environment;
# consider declaring wandb as an explicit requirement instead.
try:
    import wandb
except ImportError:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'wandb'])
    import wandb

# Log library versions for reproducibility.
print('PyTorch version:', torch.__version__)
print('Wandb version:', wandb.__version__)

Hyperparameters

Configure paths, training settings, and pretrained-vs-scratch head options.

# Central experiment configuration, read by the cells below.
hyperparameters = dict(
    dataset_roots=[
        './datasets/BCCD',
    ],
    # Training loop
    batch_size=4,
    num_epochs=20,
    learning_rate=3e-4,
    weight_decay=1e-4,
    # Data loading
    num_workers=4,
    persistent_workers=True,
    random_seed=42,

    # GPU selection (default: first available GPU)
    gpu_id=0,

    # RetinaNet initialization options
    # Swin options available out of the box: swin_t, swin_s, swin_b.
    # Use swin_l only if timm is already installed on the remote server.
    backbone_name='swin_t',
    use_pretrained_backbone=True,
    use_pretrained_retinanet_head=True,
    train_head_from_scratch=False,

    # Detection settings
    score_threshold_eval=0.05,
    score_threshold_vis=0.25,
    nms_iou_threshold=0.5,
    max_detections_per_image=200,
    map_iou_threshold=0.5,

    # Scheduler settings
    scheduler_use_restarts=True,
    scheduler_period_epochs=7,  # Used only if scheduler_use_restarts=True

    # Weights & Biases settings
    use_wandb=True,
    wandb_project='deep-learning-2026-retinanet',
    wandb_run_name='retinanet-swin-t-bccd-warmstart-head-noaug-restart',
    wandb_entity=None,  # Set your W&B username/team if needed, else keep None
    wandb_log_model=False,
    wandb_watch_model=True,
)
 
# Seed every RNG the notebook touches (Python, NumPy, PyTorch CPU and all CUDA devices).
seed = hyperparameters['random_seed']
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Pick the requested GPU; an out-of-range id (including a negative one) falls back to cuda:0.
requested_gpu_id = hyperparameters.get('gpu_id', 0)
if torch.cuda.is_available():
    n_gpus = torch.cuda.device_count()
    if 0 <= requested_gpu_id < n_gpus:
        device = torch.device(f'cuda:{requested_gpu_id}')
    else:
        print(f"Requested gpu_id={requested_gpu_id} is not valid ({n_gpus} GPU(s) available). Falling back to cuda:0.")
        device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

print('Using device:', device)
print('Random seed:', seed)

if hyperparameters.get('use_wandb', True):
    # SECURITY: credentials must never be committed to source. A W&B API key was
    # previously hard-coded here; it should be treated as leaked and revoked.
    # Provide the key via the WANDB_API_KEY environment variable or `wandb login`.
    if not os.environ.get('WANDB_API_KEY'):
        print('WANDB_API_KEY is not set; W&B will use stored `wandb login` credentials or prompt for them.')
    print('W&B logging is enabled. Run name:', hyperparameters['wandb_run_name'])
else:
    print('W&B logging is disabled.')
 
 
def is_bccd_root(p):
    """Return True when directory *p* has the VOC-style BCCD layout.

    The layout requires three subdirectories: Annotations, JPEGImages and ImageSets.
    """
    required_subdirs = ('Annotations', 'JPEGImages', 'ImageSets')
    return all((p / subdir).exists() for subdir in required_subdirs)
 
 
def resolve_bccd_root(roots):
    """Locate the BCCD dataset root directory.

    Each entry of *roots* is tried in order first. If none qualifies, the current
    working directory is walked recursively. Any directory named 'BCCD' or
    'BCCD_Dataset' is accepted if it, or its 'BCCD' child, satisfies
    ``is_bccd_root``.

    Raises:
        FileNotFoundError: When no candidate qualifies.
    """
    for configured in map(Path, roots):
        if is_bccd_root(configured):
            return configured

    # Fallback: recursive search from the current working directory.
    dataset_dir_names = {'BCCD', 'BCCD_Dataset'}
    for entry in Path.cwd().rglob('*'):
        # Cheap name test first so is_dir() only runs on plausible matches.
        if entry.name not in dataset_dir_names or not entry.is_dir():
            continue
        for option in (entry, entry / 'BCCD'):
            if is_bccd_root(option):
                return option

    raise FileNotFoundError(
        f"Could not find BCCD root in any of: {roots}. Searched recursively from {Path.cwd()} as fallback."
    )
 
 
def find_split_file(root, split_name):
    """Return the path of ``<split_name>.txt`` under ``root/ImageSets``.

    ``ImageSets/<split>.txt`` is preferred over ``ImageSets/Main/<split>.txt``.

    Raises:
        FileNotFoundError: If neither location exists.
    """
    image_sets_dir = root / 'ImageSets'
    file_name = f'{split_name}.txt'
    for candidate in (image_sets_dir / file_name, image_sets_dir / 'Main' / file_name):
        if candidate.exists():
            return candidate
    raise FileNotFoundError(f"Could not find split file for '{split_name}' in {root / 'ImageSets'}")
 
 
# Resolve the dataset location and the VOC subfolders used by the dataset class.
dataset_root = resolve_bccd_root(hyperparameters['dataset_roots'])
annotations_root = dataset_root / 'Annotations'
images_root = dataset_root / 'JPEGImages'

# One split file per predefined split, in train/val/test order.
split_files = {split: find_split_file(dataset_root, split) for split in ('train', 'val', 'test')}

print('Using BCCD root:', dataset_root.resolve())
print('Split files:')
for split_name, split_path in split_files.items():
    print(f"  {split_name}: {split_path.resolve()}")

RetinaNet Model Setup

Build RetinaNet with a selectable backbone: a torchvision ResNet (resnet18–resnet152) or a Swin Transformer. swin_t, swin_s and swin_b work with torchvision alone; swin_l additionally requires timm to be installed on the remote server.

def reset_detection_heads(model, num_classes, train_head_from_scratch=False):
    """Adapt a pretrained RetinaNet head to ``num_classes`` outputs per anchor.

    The classification head's final ``cls_logits`` conv is replaced by a new
    3x3 conv with ``num_anchors * num_classes`` output channels. The conv tower
    in front of it is kept.

    Without ``train_head_from_scratch``, the new logits are warm-started from the
    old ones:
    - If the class count is unchanged, the old weights are copied.
    - Otherwise every new class gets the per-anchor mean over the old classes.

    Args:
        model: Object exposing ``model.head.classification_head`` (with
            ``num_anchors`` and ``cls_logits``) and ``model.head.regression_head``.
        num_classes: Number of classification outputs per anchor.
        train_head_from_scratch: When True, skip the warm start and re-initialize
            every Conv2d in both heads (normal std=0.01, zero bias). The logits
            bias is then set to the focal-loss prior.

    Returns:
        The same ``model`` object, modified in place.
    """
    cls_head = model.head.classification_head
    reg_head = model.head.regression_head

    pretrained_logits = cls_head.cls_logits
    anchors = cls_head.num_anchors
    channels = pretrained_logits.in_channels
    # The logits are treated as anchor-major: output channel = anchor * K + class.
    pretrained_classes = pretrained_logits.out_channels // max(1, anchors)
    # Focal-loss prior bias, chosen so that sigmoid(bias) == 0.01.
    prior_bias = -np.log((1 - 0.01) / 0.01)

    fresh_logits = nn.Conv2d(
        channels,
        anchors * num_classes,
        kernel_size=3,
        stride=1,
        padding=1,
    )

    warm_start = not train_head_from_scratch and pretrained_logits.weight.shape[1] == channels
    if warm_start:
        with torch.no_grad():
            weight_grid = pretrained_logits.weight.data.view(anchors, pretrained_classes, channels, 3, 3)
            bias_grid = pretrained_logits.bias.data.view(anchors, pretrained_classes)
            if num_classes != pretrained_classes:
                # Every new class starts from the per-anchor mean of the pretrained classes.
                weight_grid = weight_grid.mean(dim=1, keepdim=True).repeat(1, num_classes, 1, 1, 1)
                bias_grid = bias_grid.mean(dim=1, keepdim=True).repeat(1, num_classes)
            fresh_logits.weight.copy_(weight_grid.reshape(anchors * num_classes, channels, 3, 3))
            fresh_logits.bias.copy_(bias_grid.reshape(anchors * num_classes))
    else:
        nn.init.normal_(fresh_logits.weight, std=0.01)
        nn.init.constant_(fresh_logits.bias, prior_bias)

    cls_head.cls_logits = fresh_logits
    cls_head.num_classes = num_classes

    if train_head_from_scratch:
        for head in (cls_head, reg_head):
            for module in head.modules():
                if not isinstance(module, nn.Conv2d):
                    continue
                nn.init.normal_(module.weight, std=0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        # The generic re-init above zeroed the logits bias; restore the focal-loss prior.
        nn.init.constant_(cls_head.cls_logits.bias, prior_bias)

    return model
 
 
# Torchvision ResNet constructors, keyed by normalized backbone name.
RESNET_BUILDERS = dict(
    resnet18=resnet18,
    resnet34=resnet34,
    resnet50=resnet50,
    resnet101=resnet101,
    resnet152=resnet152,
)

# Weight enums matching RESNET_BUILDERS (``.DEFAULT`` selects pretrained weights).
RESNET_WEIGHTS = dict(
    resnet18=ResNet18_Weights,
    resnet34=ResNet34_Weights,
    resnet50=ResNet50_Weights,
    resnet101=ResNet101_Weights,
    resnet152=ResNet152_Weights,
)

# Torchvision Swin constructors; entries are None when this torchvision lacks Swin.
SWIN_TORCHVISION_BUILDERS = dict(
    swin_t=swin_t,
    swin_s=swin_s,
    swin_b=swin_b,
)

# Weight enums matching SWIN_TORCHVISION_BUILDERS (None when unavailable).
SWIN_TORCHVISION_WEIGHTS = dict(
    swin_t=Swin_T_Weights,
    swin_s=Swin_S_Weights,
    swin_b=Swin_B_Weights,
)

# Patch-embedding width C of each torchvision Swin variant.
SWIN_TORCHVISION_EMBED_DIMS = dict(
    swin_t=96,
    swin_s=96,
    swin_b=128,
)

# Swin variants that are only built through timm.
TIMM_SWIN_MODELS = dict(
    swin_l='swin_large_patch4_window7_224',
)
 
 
def normalize_backbone_name(backbone_name):
    """Canonicalize a user-supplied backbone name.

    The name is stringified, stripped, lower-cased and has '-' replaced by '_'.
    Swin size aliases ('tiny', 'swin_small', 'large', ...) are then mapped to the
    short keys swin_t / swin_s / swin_b / swin_l. Anything else is returned in
    its cleaned-up form.
    """
    cleaned = str(backbone_name).strip().lower().replace('-', '_')
    alias_groups = {
        'swin_t': ('tiny', 'swin_tiny'),
        'swin_s': ('small', 'swin_small'),
        'swin_b': ('base', 'swin_base'),
        'swin_l': ('large', 'swin_large'),
    }
    for canonical, spellings in alias_groups.items():
        if cleaned in spellings:
            return canonical
    return cleaned
 
 
def format_backbone_name(backbone_name):
    """Return a display label for a backbone, e.g. 'ResNet-50' or 'Swin-T'.

    Unknown names fall back to their normalized form with '_' replaced by '-'.
    """
    key = normalize_backbone_name(backbone_name)
    resnet_labels = {f'resnet{depth}': f'ResNet-{depth}' for depth in (18, 34, 50, 101, 152)}
    swin_labels = {f'swin_{size}': f'Swin-{size.upper()}' for size in 'tsbl'}
    display_labels = {**resnet_labels, **swin_labels}
    if key in display_labels:
        return display_labels[key]
    return key.replace('_', '-')
 
 
class TorchvisionSwinBackboneWithFPN(nn.Module):
    """Feature-pyramid backbone built on a torchvision Swin Transformer.

    Outputs of ``backbone_body.features`` at indices 3, 5 and 7 are collected.
    The last of them first passes through the model's final ``norm`` when one
    exists. The outputs are permuted from channels-last to NCHW and fed into an
    FPN with P6/P7 extra levels.

    Args:
        backbone_body: torchvision Swin model (its ``features`` and optional ``norm``
            are reused; the classifier is dropped).
        in_channels_list: Channel counts of the three collected stages.
        out_channels: FPN output width (also exposed as ``self.out_channels``).
    """

    def __init__(self, backbone_body, in_channels_list, out_channels=256):
        super().__init__()
        self.features = backbone_body.features
        self.final_norm = getattr(backbone_body, 'norm', None)
        self.stage_indices = (3, 5, 7)
        self.out_channels = out_channels
        extra_levels = LastLevelP6P7(in_channels_list[-1], out_channels)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=extra_levels,
        )

    def forward(self, x):
        collected = OrderedDict()
        last_stage = self.stage_indices[-1]

        for index, block in enumerate(self.features):
            x = block(x)
            # Final LayerNorm is applied to the deepest stage only.
            if index == last_stage and self.final_norm is not None:
                x = self.final_norm(x)
            if index in self.stage_indices:
                # Swin blocks emit (N, H, W, C); the FPN expects (N, C, H, W).
                collected[str(len(collected))] = x.permute(0, 3, 1, 2).contiguous()

        return self.fpn(collected)
 
 
class TimmFeaturesBackboneWithFPN(nn.Module):
    """Feature-pyramid backbone built on a timm ``features_only`` model.

    Channel counts come from ``backbone_body.feature_info.channels()``. Feature
    maps are passed to an FPN with P6/P7 extra levels.

    Depending on the timm version and model family, Swin feature extractors may
    return channels-last (NHWC) maps. FeaturePyramidNetwork needs NCHW, so any
    map whose expected channel count sits on the last axis (and not on axis 1)
    is permuted.

    Args:
        backbone_body: timm model created with ``features_only=True``.
        out_channels: FPN output width (also exposed as ``self.out_channels``).
    """

    def __init__(self, backbone_body, out_channels=256):
        super().__init__()
        self.body = backbone_body
        in_channels_list = list(backbone_body.feature_info.channels())
        # Expected channels per returned stage; used to detect NHWC outputs.
        self.stage_channels = tuple(in_channels_list)
        self.out_channels = out_channels
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=LastLevelP6P7(in_channels_list[-1], out_channels),
        )

    @staticmethod
    def _to_nchw(feature, expected_channels):
        """Return *feature* in NCHW layout, permuting from NHWC when detectable."""
        if expected_channels is None or feature.dim() != 4:
            return feature
        if feature.shape[1] != expected_channels and feature.shape[-1] == expected_channels:
            return feature.permute(0, 3, 1, 2).contiguous()
        return feature

    def forward(self, x):
        features = self.body(x)
        if isinstance(features, dict):
            # Enumerating a dict would yield its keys, not the feature maps.
            features = list(features.values())
        elif isinstance(features, tuple):
            features = list(features)

        stage_outputs = OrderedDict()
        for idx, feature in enumerate(features):
            expected = self.stage_channels[idx] if idx < len(self.stage_channels) else None
            stage_outputs[str(idx)] = self._to_nchw(feature, expected)
        return self.fpn(stage_outputs)
 
 
def build_swin_backbone(backbone_name, use_pretrained_backbone):
    """Build a Swin Transformer + FPN backbone for RetinaNet.

    Variant sources:
    - swin_t / swin_s / swin_b come from torchvision.
    - swin_l is built with timm (``features_only``, stages 1-3) when timm is importable.

    Raises:
        RuntimeError: If any of the following holds:
            - torchvision does not provide Swin builders;
            - timm is missing for a timm-only variant;
            - timm fails to construct the model.
        ValueError: If the normalized name is not a supported backbone.
    """
    key = normalize_backbone_name(backbone_name)

    if key in SWIN_TORCHVISION_BUILDERS:
        constructor = SWIN_TORCHVISION_BUILDERS[key]
        weight_enum = SWIN_TORCHVISION_WEIGHTS[key]
        if constructor is None or weight_enum is None:
            raise RuntimeError(
                'The current torchvision build does not expose Swin Transformer backbones. '
                'Use a ResNet backbone or update torchvision on the remote server.'
            )

        body = constructor(weights=weight_enum.DEFAULT if use_pretrained_backbone else None, progress=False)
        # Collected stages of a Swin with embedding width C carry 2C, 4C and 8C channels.
        base_dim = SWIN_TORCHVISION_EMBED_DIMS[key]
        stage_channels = [base_dim * factor for factor in (2, 4, 8)]
        return TorchvisionSwinBackboneWithFPN(body, in_channels_list=stage_channels)

    timm_name = TIMM_SWIN_MODELS.get(key)
    if timm_name is None:
        supported = sorted({*RESNET_BUILDERS, *SWIN_TORCHVISION_BUILDERS, *TIMM_SWIN_MODELS})
        raise ValueError(
            f"Unsupported backbone_name='{key}'. Supported: {supported}"
        )

    if timm is None:
        raise RuntimeError(
            f"backbone_name='{key}' requires timm on the remote server. "
            'Use swin_t, swin_s or swin_b with torchvision only, or install timm first.'
        )

    try:
        body = timm.create_model(
            timm_name,
            pretrained=use_pretrained_backbone,
            features_only=True,
            out_indices=(1, 2, 3),
        )
    except Exception as exc:
        raise RuntimeError(
            f"Could not build timm backbone '{timm_name}'. Make sure the remote server has a compatible timm version."
        ) from exc

    return TimmFeaturesBackboneWithFPN(body)
 
 
def build_detection_backbone(backbone_name, use_pretrained_backbone):
    """Build an FPN backbone (with P6/P7) for RetinaNet.

    Names in RESNET_BUILDERS produce a torchvision ResNet FPN: all layers are
    trainable, and layers 2-4 are returned. Every other name is delegated to
    ``build_swin_backbone``.
    """
    if backbone_name not in RESNET_BUILDERS:
        return build_swin_backbone(backbone_name, use_pretrained_backbone)

    weights_enum = RESNET_WEIGHTS[backbone_name]
    body = RESNET_BUILDERS[backbone_name](
        weights=weights_enum.DEFAULT if use_pretrained_backbone else None,
        progress=False,
    )

    # P6/P7 are computed from C5, whose width equals the classifier's input width.
    extra_levels = LastLevelP6P7(body.fc.in_features, 256)
    return _resnet_fpn_extractor(
        body,
        trainable_layers=5,
        returned_layers=[2, 3, 4],
        extra_blocks=extra_levels,
    )
 
 
def build_model(hparams, num_classes):
    """Construct a RetinaNet detector from the hyperparameter dict.

    Backbone choice comes from ``hparams['backbone_name']``. Head choice:
    - With ``use_pretrained_retinanet_head``, the COCO-pretrained
      RetinaNet-ResNet50-FPN is loaded, its backbone is swapped, and its head
      is adapted via ``reset_detection_heads``.
    - Otherwise a fresh RetinaNet head is created.

    Input resizing, score/NMS thresholds and the detection cap are then set from
    ``hparams``.

    Raises:
        ValueError: If the normalized backbone name is not supported.
    """
    backbone_key = normalize_backbone_name(hparams.get('backbone_name', 'swin_t'))

    known_backbones = sorted({*RESNET_BUILDERS, *SWIN_TORCHVISION_BUILDERS, *TIMM_SWIN_MODELS})
    if backbone_key not in known_backbones:
        raise ValueError(
            f"Unsupported backbone_name='{backbone_key}'. Supported: {known_backbones}"
        )

    backbone = build_detection_backbone(backbone_key, hparams.get('use_pretrained_backbone', True))

    if hparams.get('use_pretrained_retinanet_head', False):
        detector = retinanet_resnet50_fpn(
            weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT,
            progress=False,
        )
        detector.backbone = backbone
        detector = reset_detection_heads(
            detector,
            num_classes=num_classes,
            train_head_from_scratch=hparams.get('train_head_from_scratch', False),
        )
    else:
        detector = RetinaNet(backbone=backbone, num_classes=num_classes)

    detector.transform.min_size = (800,)
    detector.transform.max_size = 1333
    detector.score_thresh = hparams.get('score_threshold_eval', 0.05)
    detector.nms_thresh = hparams.get('nms_iou_threshold', 0.5)
    detector.detections_per_img = hparams.get('max_detections_per_image', 200)
    return detector
 
 
def count_trainable_parameters(model):
    """Return ``(total, trainable)`` element counts over ``model.parameters()``."""
    total = trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable

Optimization Setup

Define optimizer and scheduler for RetinaNet fine-tuning.

def build_optimizer_and_scheduler(model, hparams):
    """Create an AdamW optimizer and a cosine learning-rate schedule.

    Schedule choice:
    - With ``scheduler_use_restarts`` (default True), CosineAnnealingWarmRestarts
      is used with a fixed period of ``scheduler_period_epochs``.
    - Otherwise a single CosineAnnealingLR spans ``num_epochs``.

    Both period lengths are clamped to at least 1.

    Returns:
        ``(optimizer, scheduler)``.
    """
    optimizer = optim.AdamW(
        model.parameters(),
        lr=hparams['learning_rate'],
        weight_decay=hparams.get('weight_decay', 0.0),
    )

    restart_period = max(1, int(hparams.get('scheduler_period_epochs', 5)))
    total_epochs = max(1, int(hparams.get('num_epochs', 1)))

    if not bool(hparams.get('scheduler_use_restarts', True)):
        return optimizer, optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epochs)

    warm_restarts = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=restart_period,
        T_mult=1,
    )
    return optimizer, warm_restarts

Model Summary and Trainable Parameters

Inspect the model and verify how many parameters are trainable under the selected transfer mode.

# Summarize the model only once a later cell has created the global `model`.
_model = globals().get('model', None)
if _model is not None:
    total_params, trainable_params = count_trainable_parameters(_model)
    backbone_name = hyperparameters.get('backbone_name', 'swin_t')
    print(f"Model: RetinaNet {format_backbone_name(backbone_name)}-FPN")
    for _flag in ('use_pretrained_backbone', 'use_pretrained_retinanet_head', 'train_head_from_scratch'):
        print(f'{_flag}:', hyperparameters[_flag])
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
else:
    print('Model not found yet. Run the dataset/dataloader cell first, then re-run this cell.')

Build BCCD Splits

Read predefined train/val/test splits and parse VOC XML annotations.

def read_split_ids(split_file):
    """Read image ids from a split file.

    Input is one entry per line. Blank lines are skipped, and any file extension
    on an entry is dropped via ``Path.stem``.
    """
    entries = (raw_line.strip() for raw_line in split_file.read_text().splitlines())
    return [Path(entry).stem for entry in entries if entry]
 
 
def parse_voc_annotation(xml_path):
    """Parse a Pascal VOC XML file into ``(boxes, labels)``.

    Objects missing a ``name`` or ``bndbox`` are skipped; missing coordinates
    default to 0.

    Box coordinates are converted from VOC's 1-based inclusive convention to
    0-based half-open ``[xmin, ymin, xmax, ymax]``. They are clipped at 0, and
    each box is kept at least 1 unit wide and tall.
    """
    annotation = ET.parse(xml_path).getroot()

    boxes, labels = [], []
    for obj in annotation.findall('object'):
        name = obj.findtext('name')
        bndbox = obj.find('bndbox')
        if name is None or bndbox is None:
            continue

        x1, y1, x2, y2 = (
            float(bndbox.findtext(tag, default='0'))
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )

        # Shift the start coordinate to 0-based and keep a minimum extent of 1.
        x1 = max(0.0, x1 - 1.0)
        y1 = max(0.0, y1 - 1.0)
        boxes.append([x1, y1, max(x1 + 1.0, x2), max(y1 + 1.0, y2)])
        labels.append(name)

    return boxes, labels
 
 
# Image ids for every predefined split, plus their counts.
split_ids = {split: read_split_ids(split_file) for split, split_file in split_files.items()}
print('Split sizes:', {split: len(ids) for split, ids in split_ids.items()})
 
 
def discover_labels(split_id_dict, ann_root):
    """Collect the class names present in the annotations of all splits.

    Ids without an XML file are ignored. If RBC, WBC and Platelets are all
    present, exactly that ordered list is returned (any other names are
    ignored). Otherwise all found names are returned sorted.
    """
    found = set()
    every_id = (image_id for ids in split_id_dict.values() for image_id in ids)
    for image_id in every_id:
        xml_path = ann_root / f'{image_id}.xml'
        if xml_path.exists():
            found.update(parse_voc_annotation(xml_path)[1])

    preferred = ['RBC', 'WBC', 'Platelets']
    return preferred if set(preferred) <= found else sorted(found)
 
 
class_names = discover_labels(split_ids, annotations_root)
if len(class_names) != 3:
    print(f'Warning: discovered {len(class_names)} classes: {class_names}')

# Foreground classes get ids 1..K; id 0 is reserved for background.
label_to_id = {name: index for index, name in enumerate(class_names, start=1)}
id_to_label = {index: name for name, index in label_to_id.items()}
num_classes = 1 + len(class_names)

print('Class mapping (background=0):')
for cls_name, cls_id in label_to_id.items():
    print(f'  {cls_id}: {cls_name}')
 
 
class BCCDDataset(Dataset):
    """BCCD detection dataset backed by VOC-style annotation files.

    Each item is ``(image_tensor, target)``:
    - ``image_tensor`` is the RGB image converted with ``TF.to_tensor``.
    - ``target`` has the keys boxes, labels, image_id (the dataset index),
      area, iscrowd, image_name (the file id) and size ``[height, width]``.

    Objects whose class name is absent from ``label_to_id`` are dropped,
    together with their boxes. Boxes narrower or shorter than 1 pixel after
    clipping are removed.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, image_ids, images_root, annotations_root, label_to_id):
        self.image_ids = list(image_ids)
        self.images_root = Path(images_root)
        self.annotations_root = Path(annotations_root)
        self.label_to_id = dict(label_to_id)

    def __len__(self):
        return len(self.image_ids)

    def _resolve_image_path(self, image_id):
        """Return the first existing image file for *image_id* among the known extensions."""
        for ext in self.IMAGE_EXTENSIONS:
            p = self.images_root / f'{image_id}{ext}'
            if p.exists():
                return p
        raise FileNotFoundError(f'No image found for ID {image_id} in {self.images_root}')

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_path = self._resolve_image_path(image_id)
        xml_path = self.annotations_root / f'{image_id}.xml'

        # Context manager closes the underlying file handle; convert() loads the pixels.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        width, height = image.size

        raw_boxes, raw_names = parse_voc_annotation(xml_path)
        # Filter boxes and labels TOGETHER so they stay index-aligned when an
        # annotation contains a class that is not in label_to_id.
        kept = [
            (box, self.label_to_id[name])
            for box, name in zip(raw_boxes, raw_names)
            if name in self.label_to_id
        ]

        if kept:
            boxes_tensor = torch.tensor([box for box, _ in kept], dtype=torch.float32)
            labels_tensor = torch.tensor([label for _, label in kept], dtype=torch.int64)

            # Boxes are 0-based half-open, so the valid range is [0, width] x [0, height].
            boxes_tensor[:, 0::2] = boxes_tensor[:, 0::2].clamp(0, width)
            boxes_tensor[:, 1::2] = boxes_tensor[:, 1::2].clamp(0, height)

            wh = boxes_tensor[:, 2:] - boxes_tensor[:, :2]
            valid = (wh[:, 0] > 1.0) & (wh[:, 1] > 1.0)
            boxes_tensor = boxes_tensor[valid]
            labels_tensor = labels_tensor[valid]
        else:
            boxes_tensor = torch.zeros((0, 4), dtype=torch.float32)
            labels_tensor = torch.zeros((0,), dtype=torch.int64)

        image_tensor = TF.to_tensor(np.array(image))

        area = (boxes_tensor[:, 2] - boxes_tensor[:, 0]) * (boxes_tensor[:, 3] - boxes_tensor[:, 1])
        iscrowd = torch.zeros((boxes_tensor.shape[0],), dtype=torch.int64)

        target = {
            'boxes': boxes_tensor,
            'labels': labels_tensor,
            'image_id': torch.tensor([idx], dtype=torch.int64),
            'area': area,
            'iscrowd': iscrowd,
            'image_name': image_id,
            'size': torch.tensor([height, width], dtype=torch.int64),
        }
        return image_tensor, target
 
 
def detection_collate_fn(batch):
    """Split a batch of ``(image, target)`` pairs into ``(images, targets)`` lists.

    Images are kept as a list rather than stacked, because detection inputs
    can have different sizes.
    """
    image_group, target_group = zip(*batch)
    return [*image_group], [*target_group]

Annotation Sanity Check

Visualize sample images with ground-truth bounding boxes and class labels.

def draw_boxes(ax, image_np, boxes, labels, class_map, color='lime', linewidth=2, alpha=1.0):
    """Show *image_np* on *ax* and overlay labelled ``[x1, y1, x2, y2]`` boxes.

    Each label id is looked up in *class_map*; unknown ids are shown as the
    number itself.
    """
    ax.imshow(image_np)
    ax.axis('off')
    for (x1, y1, x2, y2), label in zip(boxes, labels):
        outline = patches.Rectangle(
            (x1, y1),
            x2 - x1,
            y2 - y1,
            linewidth=linewidth,
            edgecolor=color,
            facecolor='none',
            alpha=alpha,
        )
        ax.add_patch(outline)

        label_id = int(label)
        caption = class_map.get(label_id, str(label_id))
        # Caption sits just above the box, clamped to the top image edge.
        ax.text(
            x1,
            max(0, y1 - 3),
            caption,
            color='white',
            fontsize=9,
            bbox=dict(facecolor=color, alpha=0.7, pad=1.5, edgecolor='none'),
        )
 
 
def show_ground_truth_samples(dataset, max_items=6, cols=3):
    """Plot up to *max_items* samples with their ground-truth boxes in a grid.

    Grid cells beyond the number of samples are hidden. Box labels are mapped
    with the global ``id_to_label``. When there is nothing to show, a notice is
    printed instead, because ``plt.subplots`` rejects a zero-row grid.
    """
    n = min(max_items, len(dataset))
    if n <= 0:
        print('No samples to display.')
        return

    rows = int(np.ceil(n / cols))
    # squeeze=False always yields a 2-D array of axes, even for a 1x1 grid.
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i >= n:
            ax.axis('off')
            continue

        image, target = dataset[i]
        img_np = image.permute(1, 2, 0).numpy()
        boxes = target['boxes'].numpy()
        labels = target['labels'].numpy()

        draw_boxes(ax, img_np, boxes, labels, id_to_label, color='deepskyblue')
        ax.set_title(f"{target['image_name']} ({len(labels)} objects)")

    plt.tight_layout()
    plt.show()
 
 
# Throwaway training-split dataset used only for the sanity-check plot below.
_train_vis_dataset = BCCDDataset(
    split_ids['train'], images_root=images_root, annotations_root=annotations_root, label_to_id=label_to_id
)
show_ground_truth_samples(_train_vis_dataset, max_items=6, cols=3)

Create Detection Datasets

Instantiate the train, train-eval, val and test datasets from the predefined ImageSets splits.

def _make_bccd_dataset(split):
    """Build a BCCDDataset for *split* using the globally resolved paths and label map."""
    return BCCDDataset(
        split_ids[split],
        images_root=images_root,
        annotations_root=annotations_root,
        label_to_id=label_to_id,
    )


train_dataset = _make_bccd_dataset('train')
# Separate instance over the training ids, used for evaluation passes.
train_eval_dataset = _make_bccd_dataset('train')
val_dataset = _make_bccd_dataset('val')
test_dataset = _make_bccd_dataset('test')

print('Datasets ready:')
print('  train:', len(train_dataset))
print('  train_eval:', len(train_eval_dataset))
print('  val:', len(val_dataset))
print('  test:', len(test_dataset))

DataLoaders And Model

Build detection dataloaders and instantiate RetinaNet.

loader_num_workers = hyperparameters.get('num_workers', 2)
# Persistent workers are only meaningful when worker processes exist.
use_persistent_workers = hyperparameters.get('persistent_workers', True) and loader_num_workers > 0
pin_memory = torch.cuda.is_available()

# Settings shared by every loader; only shuffling differs between them.
_loader_kwargs = dict(
    batch_size=hyperparameters['batch_size'],
    num_workers=loader_num_workers,
    pin_memory=pin_memory,
    persistent_workers=use_persistent_workers,
    collate_fn=detection_collate_fn,
)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
train_eval_loader = DataLoader(train_eval_dataset, shuffle=False, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)
 
model = build_model(hyperparameters, num_classes=num_classes).to(device)
optimizer, scheduler = build_optimizer_and_scheduler(model, hyperparameters)

backbone_name = hyperparameters.get('backbone_name', 'swin_t')
total_params, trainable_params = count_trainable_parameters(model)
print(f"Model: RetinaNet {format_backbone_name(backbone_name)}-FPN")
print(f"Trainable parameters: {trainable_params:,} / {total_params:,}")

Batch Visualization

Inspect one training batch with detection targets.

# Pull a single training batch and report its structure.
images, targets = next(iter(train_loader))
print('Batch size:', len(images))
print('Example image tensor shape:', tuple(images[0].shape))
print('Example targets keys:', list(targets[0].keys()))

n_show = min(4, len(images))
fig, axes = plt.subplots(1, n_show, figsize=(5 * n_show, 5))
# plt.subplots returns a bare Axes (not an array) when there is a single column.
axes = [axes] if n_show == 1 else axes

for i in range(n_show):
    sample_target = targets[i]
    img_np = images[i].permute(1, 2, 0).numpy()
    boxes = sample_target['boxes'].numpy()
    labels = sample_target['labels'].numpy()

    draw_boxes(axes[i], img_np, boxes, labels, id_to_label, color='gold')
    axes[i].set_title(f"{sample_target['image_name']} ({len(labels)} objects)")

plt.tight_layout()
plt.show()

Training Step

Run one RetinaNet training epoch, optimizing the sum of the classification and box-regression losses.

def train_one_epoch(model, train_loader, optimizer, device):
    """Run one optimization pass over *train_loader*.

    The model is called as ``model(images, targets)`` and must return a dict of
    losses; their sum is backpropagated.

    Classification and regression components are tracked for logging:
    - Classification uses the 'classification' key when present; otherwise the
      sum of keys mentioning 'class'/'cls'.
    - Regression uses the 'bbox_regression' key when present; otherwise the sum
      of keys mentioning 'reg'/'bbox'/'box'.

    Returns:
        ``(avg_total_loss, avg_cls_loss, avg_reg_loss)`` averaged over batches
        (0.0 each for an empty loader).
    """
    model.train()
    forwarded_keys = ('boxes', 'labels', 'image_id', 'area', 'iscrowd')

    def _loss_component(losses, exact_key, fragments, reference):
        # Prefer the exact torchvision key, else sum keys containing any fragment.
        value = losses.get(exact_key)
        if value is not None:
            return value
        matches = [v for k, v in losses.items() if any(frag in k.lower() for frag in fragments)]
        return sum(matches) if matches else torch.tensor(0.0, device=reference.device)

    sum_total = sum_cls = sum_reg = 0.0
    batch_count = 0

    for images, targets in train_loader:
        images = [img.to(device, non_blocking=True) for img in images]
        targets_device = [
            {key: t[key].to(device, non_blocking=True) for key in forwarded_keys}
            for t in targets
        ]

        loss_dict = model(images, targets_device)
        loss = sum(loss_dict.values())
        cls_loss = _loss_component(loss_dict, 'classification', ('class', 'cls'), loss)
        reg_loss = _loss_component(loss_dict, 'bbox_regression', ('reg', 'bbox', 'box'), loss)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        sum_total += float(loss.item())
        sum_cls += float(cls_loss.item())
        sum_reg += float(reg_loss.item())
        batch_count += 1

    denominator = max(1, batch_count)
    return sum_total / denominator, sum_cls / denominator, sum_reg / denominator

mAP Evaluation

Compute VOC-style (all-point interpolated) AP per class and the resulting mAP, using an IoU threshold of 0.5 by default.

def compute_ap_from_detections(scores, is_tp, total_gt):
    """Return the VOC all-point interpolated average precision for one class.

    Args:
        scores: confidence of every detection of this class (any order).
        is_tp: per-detection flag, 1.0 for a true positive and 0.0 otherwise,
            aligned with ``scores``.
        total_gt: number of ground-truth boxes of this class.

    Returns:
        ``nan`` when the class has no ground truth, ``0.0`` when there are no
        detections, otherwise the area under the monotone precision envelope.
    """
    if total_gt == 0:
        return np.nan
    if len(scores) == 0:
        return 0.0

    ranking = np.argsort(-np.asarray(scores))
    hits = np.asarray(is_tp)[ranking].astype(np.float32)
    cum_hits = np.cumsum(hits)
    cum_misses = np.cumsum(1.0 - hits)

    recall = cum_hits / max(1.0, float(total_gt))
    precision = cum_hits / np.maximum(cum_hits + cum_misses, 1e-8)

    # Sentinels close the curve at recall 0 and 1.
    recall_env = np.concatenate(([0.0], recall, [1.0]))
    precision_env = np.concatenate(([0.0], precision, [0.0]))
    # Running max from the right makes precision non-increasing in recall.
    precision_env = np.maximum.accumulate(precision_env[::-1])[::-1]

    # Integrate only where recall actually changes.
    steps = np.flatnonzero(recall_env[1:] != recall_env[:-1])
    area = np.sum((recall_env[steps + 1] - recall_env[steps]) * precision_env[steps + 1])
    return float(area)
 
 
def _greedy_match_flags(pred_boxes, gt_boxes, iou_threshold):
    """Return TP (1.0) / FP (0.0) flags for score-sorted predictions of one class in one image.

    Each prediction is compared with the ground-truth box it overlaps most. It is
    a TP only if that IoU reaches ``iou_threshold`` and that GT box was not
    already claimed by a higher-scoring prediction (VOC-style greedy matching).
    """
    num_gt = int(gt_boxes.shape[0])
    if num_gt == 0:
        return [0.0] * int(pred_boxes.shape[0])

    # One (num_pred, num_gt) IoU matrix instead of one box_iou call per prediction.
    # box_iou is presumably torchvision.ops.box_iou on (x1, y1, x2, y2) boxes.
    best_ious, best_idxs = box_iou(pred_boxes, gt_boxes).max(dim=1)
    # Compare in tensor dtype so thresholds behave exactly as a tensor comparison.
    passes = (best_ious >= iou_threshold).tolist()

    matched = [False] * num_gt
    flags = []
    for ok, gt_idx in zip(passes, best_idxs.tolist()):
        if ok and not matched[gt_idx]:
            matched[gt_idx] = True
            flags.append(1.0)
        else:
            flags.append(0.0)
    return flags


@torch.no_grad()
def evaluate_map(model, data_loader, class_ids, iou_threshold=0.5, score_threshold=0.05, device=None):
    """Evaluate ``model`` on ``data_loader`` and return VOC-style mAP.

    Predictions scoring below ``score_threshold`` are dropped. The rest are
    matched per image and per class to ground-truth boxes with greedy IoU
    matching, and per-class AP is computed by ``compute_ap_from_detections``.

    Args:
        model: detector that, in eval mode, returns one dict per image with
            'boxes', 'scores' and 'labels' tensors.
        data_loader: iterable of ``(images, targets)`` batches; each target dict
            provides 'boxes' and 'labels' tensors.
        class_ids: class ids to evaluate.
        iou_threshold: minimum IoU for a prediction to count as a true positive.
        score_threshold: predictions with a lower score are ignored.
        device: device to run inference on. ``None`` (default) uses the device
            of the model's first parameter, falling back to CPU.

    Returns:
        ``(map_value, ap_per_class)``. ``map_value`` is the mean AP over classes
        that have at least one ground-truth box, or NaN if none do.
        ``ap_per_class`` maps class id to AP, with NaN for classes without
        ground truth.
    """
    if device is None:
        # Previously relied on a module-global ``device``; infer it from the model instead.
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()

    per_class_scores = {c: [] for c in class_ids}
    per_class_tp = {c: [] for c in class_ids}
    per_class_total_gt = {c: 0 for c in class_ids}

    for images, targets in data_loader:
        images_device = [img.to(device, non_blocking=True) for img in images]
        outputs = model(images_device)

        for output, target in zip(outputs, targets):
            gt_boxes = target['boxes'].cpu()
            gt_labels = target['labels'].cpu()

            pred_boxes = output['boxes'].detach().cpu()
            pred_scores = output['scores'].detach().cpu()
            pred_labels = output['labels'].detach().cpu()

            keep = pred_scores >= score_threshold
            pred_boxes = pred_boxes[keep]
            pred_scores = pred_scores[keep]
            pred_labels = pred_labels[keep]

            for cls_id in class_ids:
                cls_gt_boxes = gt_boxes[gt_labels == cls_id]
                # Every GT box counts toward recall, even if no prediction exists.
                per_class_total_gt[cls_id] += int(cls_gt_boxes.shape[0])

                pred_mask = pred_labels == cls_id
                cls_pred_scores = pred_scores[pred_mask]
                if cls_pred_scores.numel() == 0:
                    continue

                # Greedy matching must visit predictions from highest to lowest score.
                order = torch.argsort(cls_pred_scores, descending=True)
                cls_pred_boxes = pred_boxes[pred_mask][order]
                cls_pred_scores = cls_pred_scores[order]

                per_class_scores[cls_id].extend(cls_pred_scores.tolist())
                per_class_tp[cls_id].extend(
                    _greedy_match_flags(cls_pred_boxes, cls_gt_boxes, iou_threshold)
                )

    ap_per_class = {
        cls_id: compute_ap_from_detections(
            per_class_scores[cls_id],
            per_class_tp[cls_id],
            per_class_total_gt[cls_id],
        )
        for cls_id in class_ids
    }

    # Classes without ground truth yield NaN AP and are excluded from the mean.
    valid_aps = [v for v in ap_per_class.values() if not np.isnan(v)]
    map_value = float(np.mean(valid_aps)) if len(valid_aps) > 0 else float('nan')
    return map_value, ap_per_class

Train And Test

Train RetinaNet, monitor mAP on the training and validation sets every epoch, and report the final test mAP of the checkpoint with the best validation mAP.

train_losses = []
train_maps = []
val_maps = []

best_val_map = -1.0
best_epoch = 0
# Snapshot of the initial weights; replaced whenever validation mAP improves.
best_model_state = copy.deepcopy(model.state_dict())

class_ids = sorted(id_to_label.keys())

wandb_run = None
if hyperparameters.get('use_wandb', True):
    wandb_run = wandb.init(
        project=hyperparameters.get('wandb_project', 'deep-learning-2026-retinanet'),
        entity=hyperparameters.get('wandb_entity', None),
        name=hyperparameters.get('wandb_run_name', None),
        config={
            **hyperparameters,
            'num_classes': num_classes,
            'class_names': class_names,
            'device': str(device),
            'dataset_root': str(dataset_root),
            'train_size': len(train_dataset),
            'val_size': len(val_dataset),
            'test_size': len(test_dataset),
        },
    )
    if hyperparameters.get('wandb_watch_model', True):
        wandb.watch(model, log='all', log_freq=100)

for epoch in range(hyperparameters['num_epochs']):
    start = time.time()
    # Read the LR used for this epoch *before* scheduler.step(): after the step,
    # param_groups already hold the next epoch's LR, which the old code logged.
    lr = optimizer.param_groups[0]['lr']

    train_loss, train_cls_loss, train_reg_loss = train_one_epoch(model, train_loader, optimizer, device)
    scheduler.step()

    train_map, train_ap = evaluate_map(
        model,
        train_eval_loader,
        class_ids=class_ids,
        iou_threshold=hyperparameters['map_iou_threshold'],
        score_threshold=hyperparameters['score_threshold_eval'],
    )
    val_map, val_ap = evaluate_map(
        model,
        val_loader,
        class_ids=class_ids,
        iou_threshold=hyperparameters['map_iou_threshold'],
        score_threshold=hyperparameters['score_threshold_eval'],
    )

    train_losses.append(train_loss)
    train_maps.append(train_map)
    val_maps.append(val_map)

    elapsed = time.time() - start
    print(
        f"Epoch {epoch + 1:02d}/{hyperparameters['num_epochs']} | "
        f"{elapsed:.1f}s | lr={lr:.2e} | loss={train_loss:.4f} (cls={train_cls_loss:.4f}, reg={train_reg_loss:.4f}) | "
        f"train mAP@{hyperparameters['map_iou_threshold']:.2f}={train_map:.4f} | "
        f"val mAP@{hyperparameters['map_iou_threshold']:.2f}={val_map:.4f}"
    )

    # A NaN val mAP (no class with validation GT) never compares greater, so it never wins.
    if val_map > best_val_map:
        best_val_map = val_map
        best_epoch = epoch + 1
        best_model_state = copy.deepcopy(model.state_dict())

    if wandb_run is not None:
        epoch_log = {
            'epoch': epoch + 1,
            'train/loss': float(train_loss),
            'train/loss_classification': float(train_cls_loss),
            'train/loss_regression': float(train_reg_loss),
            'train/map': float(train_map),
            'val/map': float(val_map),
            'train/val_gap_map': float(train_map - val_map),
            'train/lr': float(lr),
            'train/epoch_time_sec': float(elapsed),
            'best/val_map_so_far': float(best_val_map),
            'best/epoch_so_far': int(best_epoch),
        }

        for cls_id in class_ids:
            cls_name = id_to_label[cls_id]
            epoch_log[f'train_ap/{cls_name}'] = float(train_ap[cls_id])
            epoch_log[f'val_ap/{cls_name}'] = float(val_ap[cls_id])

        wandb.log(epoch_log, step=epoch + 1)

# Evaluate on test with the weights that scored best on validation.
model.load_state_dict(best_model_state)

test_map, test_ap = evaluate_map(
    model,
    test_loader,
    class_ids=class_ids,
    iou_threshold=hyperparameters['map_iou_threshold'],
    score_threshold=hyperparameters['score_threshold_eval'],
)

print()
print(f"Best validation mAP@{hyperparameters['map_iou_threshold']:.2f}: {best_val_map:.4f} at epoch {best_epoch}")
print(f"Test mAP@{hyperparameters['map_iou_threshold']:.2f}: {test_map:.4f}")
print('Test AP per class:')
for cls_id in class_ids:
    print(f"  {id_to_label[cls_id]}: {test_ap[cls_id]:.4f}")

if wandb_run is not None:
    final_log = {
        'best/epoch': int(best_epoch),
        'best/val_map': float(best_val_map),
        'test/map': float(test_map),
    }
    for cls_id in class_ids:
        cls_name = id_to_label[cls_id]
        final_log[f'test_ap/{cls_name}'] = float(test_ap[cls_id])
    wandb.log(final_log, step=hyperparameters['num_epochs'])

    if hyperparameters.get('wandb_log_model', False):
        model_artifact = wandb.Artifact(
            name=f"{wandb.run.name}-weights",
            type='model',
            description='Best RetinaNet weights (by validation mAP)',
        )
        weights_path = Path('retinanet_best_model.pth')
        torch.save(best_model_state, weights_path)
        model_artifact.add_file(str(weights_path))
        wandb.log_artifact(model_artifact)

epochs = np.arange(1, len(train_losses) + 1)
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(epochs, train_losses, marker='o', label='Train loss')
plt.title('RetinaNet training loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(alpha=0.25)
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epochs, train_maps, marker='o', label='Train mAP')
plt.plot(epochs, val_maps, marker='o', label='Validation mAP')
plt.title(f"mAP@{hyperparameters['map_iou_threshold']:.2f} per epoch")
plt.xlabel('Epoch')
plt.ylabel('mAP')
plt.ylim(0, 1.0)
plt.grid(alpha=0.25)
plt.legend()

plt.tight_layout()
plt.show()

if wandb_run is not None:
    wandb.finish()

Qualitative Results

Side-by-side visualization of ground-truth boxes and predicted boxes (with class names and confidence scores).

def draw_prediction_boxes(ax, image_np, boxes, labels, scores, class_map, score_thr=0.25):
    """Show ``image_np`` on ``ax`` and overlay confident predicted boxes.

    Args:
        ax: matplotlib axes to draw on (its axis decorations are turned off).
        image_np: image array passed to ``ax.imshow``.
        boxes: per-detection ``(x1, y1, x2, y2)`` corners.
        labels: per-detection integer class ids.
        scores: per-detection confidences.
        class_map: mapping from class id to display name; unknown ids are shown
            as their number.
        score_thr: detections scoring below this value are not drawn.
    """
    ax.imshow(image_np)
    ax.axis('off')

    visible = (
        (box, label, score)
        for box, label, score in zip(boxes, labels, scores)
        if not score < score_thr
    )
    for (left, top, right, bottom), label, score in visible:
        label_id = int(label)
        ax.add_patch(
            patches.Rectangle(
                (left, top),
                right - left,
                bottom - top,
                linewidth=2,
                edgecolor='crimson',
                facecolor='none',
            )
        )
        # Caption sits just above the box, clamped so it never leaves the image top.
        ax.text(
            left,
            max(0, top - 3),
            f"{class_map.get(label_id, str(label_id))} {score:.2f}",
            color='white',
            fontsize=9,
            bbox={'facecolor': 'crimson', 'alpha': 0.75, 'edgecolor': 'none', 'pad': 1.5},
        )
 
 
@torch.no_grad()
def show_detection_results(model, data_loader, max_items=4, score_thr=0.25, device=None):
    """Plot ground truth (left column) next to predictions (right column).

    Args:
        model: detector that, in eval mode, returns one dict per image with
            'boxes', 'labels' and 'scores' tensors.
        data_loader: iterable of ``(images, targets)`` batches; each target dict
            provides 'boxes', 'labels' and 'image_name'.
        max_items: number of images (rows) to show; must be at least 1.
        score_thr: predictions scoring below this value are not drawn.
        device: device to run inference on. ``None`` (default) uses the device
            of the model's first parameter, falling back to CPU.

    Raises:
        ValueError: if ``max_items`` is smaller than 1.
    """
    if max_items < 1:
        raise ValueError(f"max_items must be >= 1, got {max_items}")

    if device is None:
        # Previously relied on a module-global ``device``; infer it from the model instead.
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    # squeeze=False keeps ``axes`` 2-D even when max_items == 1.
    _, axes = plt.subplots(max_items, 2, figsize=(12, 5 * max_items), squeeze=False)
    shown = 0

    for images, targets in data_loader:
        if shown >= max_items:
            break
        outputs = model([img.to(device) for img in images])

        for image, target, output in zip(images, targets, outputs):
            if shown >= max_items:
                break

            # CHW tensor -> HWC array for imshow.
            image_np = image.permute(1, 2, 0).cpu().numpy()
            gt_ax, pred_ax = axes[shown]

            draw_boxes(
                gt_ax,
                image_np,
                target['boxes'].cpu().numpy(),
                target['labels'].cpu().numpy(),
                id_to_label,
                color='dodgerblue',
            )
            gt_ax.set_title(f"GT: {target['image_name']}")

            draw_prediction_boxes(
                pred_ax,
                image_np,
                output['boxes'].detach().cpu().numpy(),
                output['labels'].detach().cpu().numpy(),
                output['scores'].detach().cpu().numpy(),
                id_to_label,
                score_thr=score_thr,
            )
            pred_ax.set_title(f"Prediction (score >= {score_thr:.2f})")

            shown += 1

    # Hide rows that stayed empty because the loader ran out of images.
    for row in axes[shown:]:
        for ax in row:
            ax.axis('off')

    plt.tight_layout()
    plt.show()
 
 
# Qualitative panel at the configured visualization threshold.
show_detection_results(
    model,
    test_loader,
    max_items=4,
    score_thr=hyperparameters['score_threshold_vis'],
)

# Visual comparison at different confidence thresholds
for th in (0.15, 0.35, 0.55):
    print(f'Visualization threshold: {th:.2f}')
    show_detection_results(model, test_loader, max_items=2, score_thr=th)

# Compact AP summary on train/val/test with the best model loaded
eval_kwargs = {
    'class_ids': class_ids,
    'iou_threshold': hyperparameters['map_iou_threshold'],
    'score_threshold': hyperparameters['score_threshold_eval'],
}
train_map_best, train_ap_best = evaluate_map(model, train_eval_loader, **eval_kwargs)
val_map_best, val_ap_best = evaluate_map(model, val_loader, **eval_kwargs)

iou_label = f"{hyperparameters['map_iou_threshold']:.2f}"
# Test metrics were already computed right after loading the best checkpoint.
for split_name, split_map, split_ap in (
    ('Train', train_map_best, train_ap_best),
    ('Val', val_map_best, val_ap_best),
    ('Test', test_map, test_ap),
):
    print(f"{split_name} mAP@{iou_label}: {split_map:.4f}")
    for cls_id in class_ids:
        print(f"  {split_name} AP {id_to_label[cls_id]}: {split_ap[cls_id]:.4f}")

# Final qualitative panel
show_detection_results(
    model,
    test_loader,
    max_items=6,
    score_thr=hyperparameters['score_threshold_vis'],
)