
from pathlib import Path
from urllib.request import Request, urlopen
import json
import shutil
import zipfile
# URLs of the JSON manifests describing the external resources this file needs.
# The loop over RESOURCE_MANIFESTS further down downloads each manifest and parses
# it as JSON.
RESOURCE_MANIFESTS = [
"https://assets.deeplearningnotes.com/code-support-resources/datasets/bccd/latest.json",
]
def download_file(url, path, timeout=60):
    """Download ``url`` and store the response body at ``path``.

    The body is first streamed into a sibling ``<name>.part`` file and only
    moved onto ``path`` once the transfer has finished, so an interrupted
    download never leaves a truncated file at the final location.

    Args:
        url: Address to fetch.
        path: Destination file (``str`` or ``Path``); overwritten if present.
        timeout: Seconds passed to ``urlopen`` as its blocking-operation
            timeout. Defaults to 60.

    Raises:
        urllib.error.URLError: If the request fails.
        OSError: If the destination cannot be written.
    """
    destination = Path(path)
    partial = destination.with_name(destination.name + ".part")
    # Browser-like User-Agent; presumably the host rejects urllib's default one.
    request = Request(url, headers={"User-Agent": "Mozilla/5.0"})
    try:
        with urlopen(request, timeout=timeout) as response, open(partial, "wb") as file:
            shutil.copyfileobj(response, file)
        partial.replace(destination)
    except BaseException:
        # Remove the incomplete temporary file, then let the error propagate.
        try:
            partial.unlink()
        except FileNotFoundError:
            pass
        raise
def extract_zip_safely(zip_path, target_dir="."):
    """Extract ``zip_path`` into ``target_dir`` after validating every member.

    Every member's resolved destination must lie inside ``target_dir``;
    nothing is extracted unless all members pass the check.

    Args:
        zip_path: Path of the archive to read.
        target_dir: Directory to extract into. Defaults to the current directory.

    Raises:
        RuntimeError: If any member would resolve outside ``target_dir``.
    """
    target_dir = Path(target_dir).resolve()
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        for member in zip_ref.infolist():
            target_path = (target_dir / member.filename).resolve()
            # Compare path components rather than string prefixes: a sibling
            # such as "/data/out_evil" starts with "/data/out" but is outside it.
            try:
                target_path.relative_to(target_dir)
            except ValueError:
                raise RuntimeError(f"Unsafe zip path: {member.filename}") from None
        zip_ref.extractall(target_dir)
# Bootstrap: download each resource manifest and, when any path it lists is
# missing locally, download and unpack the archive the manifest points to.
for manifest_url in RESOURCE_MANIFESTS:
    # The resource name is the URL's parent segment (".../bccd/latest.json" -> "bccd").
    name = manifest_url.rstrip("/").split("/")[-2]
    manifest_path = Path(f"{name}-latest.json")
    archive_path = Path(f"{name}.zip")

    download_file(manifest_url, manifest_path)
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))

    expected_paths = [Path(path) for path in manifest.get("expected_paths", [])]
    if not all(path.exists() for path in expected_paths):
        download_file(manifest["archive_url"], archive_path)
        extract_zip_safely(archive_path, manifest.get("extract_to", "."))

    # Re-check after extraction so a bad archive fails here, not later.
    missing = [str(path) for path in expected_paths if not path.exists()]
    if missing:
        raise FileNotFoundError(f"Missing expected paths: {missing}")
    print(f"{manifest['name']} ready.")


# Imports
# Libraries for blood cell object detection and classification with RetinaNet on BCCD.
import copy
import os
import random
import subprocess
import sys
import time
import xml.etree.ElementTree as ET
from collections import defaultdict
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import functional as TF
from torchvision.ops import box_iou
from torchvision.models import (
ResNet18_Weights,
ResNet34_Weights,
ResNet50_Weights,
ResNet101_Weights,
ResNet152_Weights,
resnet18,
resnet34,
resnet50,
resnet101,
resnet152,
)
from torchvision.models.detection import (
RetinaNet,
RetinaNet_ResNet50_FPN_Weights,
retinanet_resnet50_fpn,
)
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor
from torchvision.ops.feature_pyramid_network import LastLevelP6P7
# Import wandb; if it is not installed, install it with pip for the running
# interpreter and import it again.
try:
    import wandb
except ImportError:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'wandb'])
    import wandb

print('PyTorch version:', torch.__version__)
print('Wandb version:', wandb.__version__)


# Hyperparameters
# Configure paths, training settings, and pretrained-vs-scratch head options.
# Run configuration, assembled one section at a time. Keys are inserted in the
# order the sections appear below.
hyperparameters = {}

# Dataset location, core training settings and seeding.
hyperparameters.update({
    'dataset_roots': [
        './datasets/BCCD',
    ],
    'batch_size': 4,
    'num_epochs': 20,
    'learning_rate': 3e-4,
    'weight_decay': 1e-4,
    'num_workers': 4,
    'persistent_workers': True,
    'random_seed': 42,
})

# GPU selection (default: first available GPU)
hyperparameters.update({
    'gpu_id': 0,
})

# RetinaNet initialization options
hyperparameters.update({
    'backbone_name': 'resnet50',
    'use_pretrained_backbone': True,
    'use_pretrained_retinanet_head': True,
    'train_head_from_scratch': False,
})

# Detection settings
hyperparameters.update({
    'score_threshold_eval': 0.05,
    'score_threshold_vis': 0.25,
    'nms_iou_threshold': 0.5,
    'max_detections_per_image': 200,
    'map_iou_threshold': 0.5,
})

# Scheduler settings
hyperparameters.update({
    'scheduler_use_restarts': True,
    'scheduler_period_epochs': 7,  # Used only if scheduler_use_restarts=True
})

# Weights & Biases settings
hyperparameters.update({
    'use_wandb': True,
    'wandb_project': 'deep-learning-2026-retinanet',
    'wandb_run_name': 'retinanet-resnet50-bccd-baseline-noaug-restart',
    'wandb_entity': None,  # Set your W&B username/team if needed, else keep None
    'wandb_log_model': False,
    'wandb_watch_model': True,
})
# Seed the Python, NumPy and PyTorch random number generators from the config.
seed = hyperparameters['random_seed']
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Pick the compute device: the requested GPU when it exists, otherwise cuda:0,
# or the CPU when CUDA is unavailable.
requested_gpu_id = hyperparameters.get('gpu_id', 0)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    n_gpus = torch.cuda.device_count()
    # The lower bound keeps a negative id from producing an invalid device string.
    if 0 <= requested_gpu_id < n_gpus:
        device = torch.device(f'cuda:{requested_gpu_id}')
    else:
        print(f"Requested gpu_id={requested_gpu_id}, but only {n_gpus} GPU(s) available. Falling back to cuda:0.")
        device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

print('Using device:', device)
print('Random seed:', seed)

if hyperparameters.get('use_wandb', True):
    # SECURITY: credentials must never be hard-coded in source. The API key is
    # expected in the WANDB_API_KEY environment variable (or from a previous
    # `wandb login`). The key that used to be embedded here must be treated as
    # leaked and revoked.
    if not os.environ.get('WANDB_API_KEY'):
        print('WANDB_API_KEY is not set; export it or run `wandb login` before training.')
    print('W&B logging is enabled. Run name:', hyperparameters['wandb_run_name'])
else:
    print('W&B logging is disabled.')
def is_bccd_root(p):
    """Return True when ``p`` contains ``Annotations``, ``JPEGImages`` and ``ImageSets``."""
    required_entries = ('Annotations', 'JPEGImages', 'ImageSets')
    return all((p / entry).exists() for entry in required_entries)
def resolve_bccd_root(roots):
    """Locate the BCCD dataset directory.

    The configured ``roots`` are tried in order. If none qualifies, the
    current working directory is searched recursively for a directory named
    ``BCCD`` or ``BCCD_Dataset`` that either is a dataset root itself or
    contains one in a ``BCCD`` subdirectory.

    Args:
        roots: Iterable of candidate paths (anything ``Path`` accepts).

    Returns:
        The first ``Path`` for which ``is_bccd_root`` is true.

    Raises:
        FileNotFoundError: If no dataset root is found.
    """
    configured = (Path(entry) for entry in roots)
    direct_hit = next((path for path in configured if is_bccd_root(path)), None)
    if direct_hit is not None:
        return direct_hit

    # Fallback: search from current working directory.
    accepted_names = {'BCCD', 'BCCD_Dataset'}
    for found in Path.cwd().rglob('*'):
        if not (found.is_dir() and found.name in accepted_names):
            continue
        # Accept the directory itself first, then a nested "BCCD" folder.
        for option in (found, found / 'BCCD'):
            if is_bccd_root(option):
                return option

    raise FileNotFoundError(
        f"Could not find BCCD root in any of: {roots}. Searched recursively from {Path.cwd()} as fallback."
    )
def find_split_file(root, split_name):
    """Return the path of ``<split_name>.txt`` under ``root / 'ImageSets'``.

    ``ImageSets/`` is checked before ``ImageSets/Main/``.

    Raises:
        FileNotFoundError: If the file exists in neither location.
    """
    image_sets = root / 'ImageSets'
    filename = f'{split_name}.txt'
    for folder in (image_sets, image_sets / 'Main'):
        split_path = folder / filename
        if split_path.exists():
            return split_path
    raise FileNotFoundError(f"Could not find split file for '{split_name}' in {root / 'ImageSets'}")
# Resolve the dataset directory and the files that define each split.
dataset_root = resolve_bccd_root(hyperparameters['dataset_roots'])
annotations_root = dataset_root / 'Annotations'
images_root = dataset_root / 'JPEGImages'
split_files = {
    split: find_split_file(dataset_root, split)
    for split in ('train', 'val', 'test')
}

print('Using BCCD root:', dataset_root.resolve())
print('Split files:')
for split_name, split_path in split_files.items():
    print(f" {split_name}: {split_path.resolve()}")


# RetinaNet Model Setup
# Build RetinaNet with a selectable ResNet-FPN backbone and configurable head initialization.
def reset_detection_heads(model, num_classes, train_head_from_scratch=False):
    """Replace the classification logits layer so the head predicts ``num_classes``.

    Unless ``train_head_from_scratch`` is set, the new logits layer is
    warm-started from the old one: copied as-is when the class count is
    unchanged, otherwise each anchor's weights and bias are the mean over the
    old classes, repeated for every new class. With ``train_head_from_scratch``
    every ``nn.Conv2d`` in both the classification and regression heads is
    re-initialised (weights ~ N(0, 0.01), zero bias) and the logits bias is
    set to the focal-loss prior.

    Args:
        model: Detector exposing ``head.classification_head`` and
            ``head.regression_head``.
        num_classes: Number of classes the new logits layer predicts per anchor.
        train_head_from_scratch: Discard head weights instead of warm-starting.

    Returns:
        The same ``model`` object, modified in place.
    """
    classification_head = model.head.classification_head
    regression_head = model.head.regression_head

    previous_logits = classification_head.cls_logits
    anchors = classification_head.num_anchors
    channels = previous_logits.in_channels
    previous_classes = previous_logits.out_channels // max(1, anchors)

    def _reinit_convs(root):
        """Re-initialise every Conv2d under ``root`` (normal weights, zero bias)."""
        for layer in root.modules():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.normal_(layer.weight, std=0.01)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    # Keep the pretrained conv tower and replace only class logits.
    # When possible, warm-start logits from pretrained weights to stabilize epoch-1 mAP.
    replacement = nn.Conv2d(
        channels,
        anchors * num_classes,
        kernel_size=3,
        stride=1,
        padding=1,
    )

    warm_start = (not train_head_from_scratch) and previous_logits.weight.shape[1] == channels
    if warm_start:
        with torch.no_grad():
            weights_by_anchor = previous_logits.weight.data.view(anchors, previous_classes, channels, 3, 3)
            bias_by_anchor = previous_logits.bias.data.view(anchors, previous_classes)
            if num_classes != previous_classes:
                # Mean across old classes per anchor gives a stable prior for new classes.
                weights_by_anchor = weights_by_anchor.mean(dim=1, keepdim=True).repeat(1, num_classes, 1, 1, 1)
                bias_by_anchor = bias_by_anchor.mean(dim=1, keepdim=True).repeat(1, num_classes)
            replacement.weight.copy_(weights_by_anchor.reshape(anchors * num_classes, channels, 3, 3))
            replacement.bias.copy_(bias_by_anchor.reshape(anchors * num_classes))
    else:
        nn.init.normal_(replacement.weight, std=0.01)
        nn.init.constant_(replacement.bias, -np.log((1 - 0.01) / 0.01))

    classification_head.cls_logits = replacement
    classification_head.num_classes = num_classes

    if train_head_from_scratch:
        _reinit_convs(classification_head)
        _reinit_convs(regression_head)
        # Focal loss bias initialization after generic reinit
        nn.init.constant_(classification_head.cls_logits.bias, -np.log((1 - 0.01) / 0.01))

    return model
# One row per supported backbone: (config name, constructor, weights enum).
# Both lookup tables are derived from it so their keys cannot drift apart.
_RESNET_VARIANTS = (
    ('resnet18', resnet18, ResNet18_Weights),
    ('resnet34', resnet34, ResNet34_Weights),
    ('resnet50', resnet50, ResNet50_Weights),
    ('resnet101', resnet101, ResNet101_Weights),
    ('resnet152', resnet152, ResNet152_Weights),
)

# Backbone name -> torchvision constructor.
RESNET_BUILDERS = {variant: builder for variant, builder, _ in _RESNET_VARIANTS}

# Backbone name -> torchvision weights enum.
RESNET_WEIGHTS = {variant: weights for variant, _, weights in _RESNET_VARIANTS}
def build_model(hparams, num_classes):
    """Construct a RetinaNet detector according to ``hparams``.

    With ``use_pretrained_retinanet_head`` and a ``resnet50`` backbone, the
    torchvision pretrained RetinaNet is loaded and its heads are adapted by
    ``reset_detection_heads``. In every other case a ResNet-FPN backbone is
    built (optionally with pretrained backbone weights) and wrapped in a fresh
    ``RetinaNet``.

    Args:
        hparams: Mapping read for ``backbone_name``, ``use_pretrained_backbone``,
            ``use_pretrained_retinanet_head``, ``train_head_from_scratch``,
            ``score_threshold_eval``, ``nms_iou_threshold`` and
            ``max_detections_per_image``.
        num_classes: Number of classes passed to the detection head.

    Returns:
        The configured model.

    Raises:
        ValueError: If ``backbone_name`` is not a key of ``RESNET_BUILDERS``.
    """
    pretrained_backbone = hparams.get('use_pretrained_backbone', True)
    pretrained_head = hparams.get('use_pretrained_retinanet_head', False)
    head_from_scratch = hparams.get('train_head_from_scratch', False)
    backbone_name = str(hparams.get('backbone_name', 'resnet50')).lower()

    if backbone_name not in RESNET_BUILDERS:
        raise ValueError(
            f"Unsupported backbone_name='{backbone_name}'. Supported: {sorted(RESNET_BUILDERS.keys())}"
        )

    if pretrained_head and backbone_name == 'resnet50':
        if not pretrained_backbone:
            print('use_pretrained_retinanet_head=True requires pretrained backbone weights. Ignoring use_pretrained_backbone=False.')
        detector = reset_detection_heads(
            retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT),
            num_classes=num_classes,
            train_head_from_scratch=head_from_scratch,
        )
    else:
        # Reaching this branch with pretrained_head set implies a non-resnet50 backbone.
        if pretrained_head:
            print('Pretrained RetinaNet head is only available for backbone_name=resnet50. Building custom head instead.')
        weights = RESNET_WEIGHTS[backbone_name].DEFAULT if pretrained_backbone else None
        body = RESNET_BUILDERS[backbone_name](weights=weights)
        # For LastLevelP6P7, the input channels match the C5 stage output channels.
        c5_channels = body.fc.in_features
        fpn_backbone = _resnet_fpn_extractor(
            body,
            trainable_layers=5,
            returned_layers=[2, 3, 4],
            extra_blocks=LastLevelP6P7(c5_channels, 256),
        )
        detector = RetinaNet(backbone=fpn_backbone, num_classes=num_classes)

    # Input resizing limits and post-processing settings.
    detector.transform.min_size = (800,)
    detector.transform.max_size = 1333
    detector.score_thresh = hparams.get('score_threshold_eval', 0.05)
    detector.nms_thresh = hparams.get('nms_iou_threshold', 0.5)
    detector.detections_per_img = hparams.get('max_detections_per_image', 200)
    return detector
def count_trainable_parameters(model):
    """Count the parameters of ``model``.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        ``(total, trainable)``: the number of scalar parameters overall and
        the number whose ``requires_grad`` flag is set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return total, trainable


# Optimization Setup
# Define optimizer and scheduler for RetinaNet fine-tuning.
def build_optimizer_and_scheduler(model, hparams):
    """Create the AdamW optimizer and a cosine learning-rate scheduler.

    Args:
        model: Module whose ``parameters()`` are optimised.
        hparams: Mapping with ``learning_rate`` (required) and optional
            ``weight_decay`` (default 0.0), ``scheduler_use_restarts``
            (default True), ``scheduler_period_epochs`` (default 5) and
            ``num_epochs`` (default 1).

    Returns:
        ``(optimizer, scheduler)``. The scheduler is
        ``CosineAnnealingWarmRestarts`` with period ``scheduler_period_epochs``
        when restarts are enabled, otherwise ``CosineAnnealingLR`` spanning
        ``num_epochs``.
    """
    optimizer = optim.AdamW(
        model.parameters(),
        lr=hparams['learning_rate'],
        weight_decay=hparams.get('weight_decay', 0.0),
    )

    use_restarts = bool(hparams.get('scheduler_use_restarts', True))
    # Both periods are floored at 1 so the schedulers always get a valid length.
    period_epochs = max(1, int(hparams.get('scheduler_period_epochs', 5)))
    total_epochs = max(1, int(hparams.get('num_epochs', 1)))

    if use_restarts:
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=period_epochs,
            T_mult=1,
        )
    else:
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=total_epochs,
        )
    return optimizer, scheduler


# Model Summary and Trainable Parameters
# Inspect the model and verify how many parameters are trainable under the selected transfer mode.
# Summarise the model only if a global named ``model`` already exists;
# otherwise tell the user to create it first.
_model = globals().get('model', None)
if _model is None:
    print('Model not found yet. Run the dataset/dataloader cell first, then re-run this cell.')
else:
    total_params, trainable_params = count_trainable_parameters(_model)
    backbone_name = str(hyperparameters.get('backbone_name', 'resnet50')).lower()
    print(f"Model: RetinaNet {backbone_name.replace('resnet', 'ResNet')}-FPN")
    print('use_pretrained_backbone:', hyperparameters['use_pretrained_backbone'])
    print('use_pretrained_retinanet_head:', hyperparameters['use_pretrained_retinanet_head'])
    print('train_head_from_scratch:', hyperparameters['train_head_from_scratch'])
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")


# Build BCCD Splits
# Read predefined train/val/test splits and parse VOC XML annotations.
def read_split_ids(split_file):
    """Return the image IDs listed in ``split_file``, one per non-blank line.

    Each line is stripped; directory parts and the final extension are
    removed by taking ``Path(line).stem``.

    Args:
        split_file: ``Path`` of the split listing.

    Returns:
        List of ID strings in file order.
    """
    entries = (raw.strip() for raw in split_file.read_text().splitlines())
    return [Path(entry).stem for entry in entries if entry]
def parse_voc_annotation(xml_path):
    """Read the objects of a VOC-style XML annotation file.

    ``<object>`` elements lacking a ``<name>`` or a ``<bndbox>`` are skipped.
    A missing coordinate tag is read as ``0``.

    Args:
        xml_path: Path of the XML file.

    Returns:
        ``(boxes, labels)``: ``boxes`` is a list of ``[xmin, ymin, xmax, ymax]``
        floats and ``labels`` the matching class-name strings.
    """
    boxes, labels = [], []
    for node in ET.parse(xml_path).getroot().findall('object'):
        cls_name = node.findtext('name')
        if cls_name is None:
            continue
        bndbox = node.find('bndbox')
        if bndbox is None:
            continue

        left, top, right, bottom = (
            float(bndbox.findtext(tag, default='0'))
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )

        # VOC coordinates are usually 1-based inclusive; convert to 0-based half-open style.
        left = max(0.0, left - 1.0)
        top = max(0.0, top - 1.0)
        # Force every box to span at least one pixel in each direction.
        right = max(left + 1.0, right)
        bottom = max(top + 1.0, bottom)

        boxes.append([left, top, right, bottom])
        labels.append(cls_name)
    return boxes, labels
# Load the image IDs of every split, keyed by split name.
split_ids = dict(zip(split_files, map(read_split_ids, split_files.values())))
print('Split sizes:', dict(zip(split_ids, map(len, split_ids.values()))))
def discover_labels(split_id_dict, ann_root):
    """Collect the class names that occur in the annotations of all splits.

    Image IDs without an XML file under ``ann_root`` are ignored.

    Args:
        split_id_dict: Mapping of split name to a list of image IDs.
        ann_root: ``Path`` of the directory holding ``<image_id>.xml`` files.

    Returns:
        ``['RBC', 'WBC', 'Platelets']`` when all three occur; otherwise the
        discovered names in sorted order.
    """
    discovered = set()
    for image_ids in split_id_dict.values():
        annotation_files = (ann_root / f'{image_id}.xml' for image_id in image_ids)
        for xml_path in annotation_files:
            if xml_path.exists():
                discovered.update(parse_voc_annotation(xml_path)[1])

    preferred = ['RBC', 'WBC', 'Platelets']
    return preferred if discovered.issuperset(preferred) else sorted(discovered)
# Derive the class list and the two lookup tables between names and integer IDs.
class_names = discover_labels(split_ids, annotations_root)
if len(class_names) != 3:
    print(f'Warning: discovered {len(class_names)} classes: {class_names}')

# IDs start at 1; 0 is reserved for the background class.
label_to_id = {name: class_id for class_id, name in enumerate(class_names, start=1)}
id_to_label = {class_id: name for name, class_id in label_to_id.items()}
num_classes = 1 + len(class_names)  # include background class 0

print('Class mapping (background=0):')
for cls_name, cls_id in label_to_id.items():
    print(f' {cls_id}: {cls_name}')
class BCCDDataset(Dataset):
    """Detection dataset pairing images with boxes parsed from VOC XML files.

    Each item is ``(image_tensor, target)``. ``target`` is a dict with the keys
    ``boxes``, ``labels``, ``image_id``, ``area``, ``iscrowd``, ``image_name``
    and ``size``.
    """

    def __init__(self, image_ids, images_root, annotations_root, label_to_id):
        """Store the sample list and lookup data.

        Args:
            image_ids: Iterable of image ID strings (file stems).
            images_root: Directory containing the image files.
            annotations_root: Directory containing ``<image_id>.xml`` files.
            label_to_id: Mapping of class name to integer label.
        """
        self.image_ids = list(image_ids)
        self.images_root = Path(images_root)
        self.annotations_root = Path(annotations_root)
        self.label_to_id = dict(label_to_id)

    def __len__(self):
        """Return the number of image IDs."""
        return len(self.image_ids)

    def _resolve_image_path(self, image_id):
        """Return the first existing ``<image_id><ext>`` among the known extensions.

        Raises:
            FileNotFoundError: If no file with a known extension exists.
        """
        for ext in ['.jpg', '.jpeg', '.png', '.bmp']:
            p = self.images_root / f'{image_id}{ext}'
            if p.exists():
                return p
        raise FileNotFoundError(f'No image found for ID {image_id} in {self.images_root}')

    def _encode_objects(self, boxes, class_names, width, height):
        """Turn parsed boxes and class names into aligned tensors.

        Objects whose class is not in ``label_to_id`` are dropped together with
        their box, so boxes and labels always stay aligned. Coordinates are
        clamped to ``[0, width - 1]`` / ``[0, height - 1]`` and boxes that are
        not more than one pixel wide and high after clamping are removed.

        Returns:
            ``(boxes_tensor, labels_tensor)`` with shapes ``(N, 4)`` float32
            and ``(N,)`` int64.
        """
        known = [
            (box, self.label_to_id[name])
            for box, name in zip(boxes, class_names)
            if name in self.label_to_id
        ]
        if not known:
            return (
                torch.zeros((0, 4), dtype=torch.float32),
                torch.zeros((0,), dtype=torch.int64),
            )

        boxes_tensor = torch.tensor([box for box, _ in known], dtype=torch.float32).reshape(-1, 4)
        labels_tensor = torch.tensor([label for _, label in known], dtype=torch.int64)

        boxes_tensor[:, 0::2] = boxes_tensor[:, 0::2].clamp(0, width - 1)
        boxes_tensor[:, 1::2] = boxes_tensor[:, 1::2].clamp(0, height - 1)

        wh = boxes_tensor[:, 2:] - boxes_tensor[:, :2]
        valid = (wh[:, 0] > 1.0) & (wh[:, 1] > 1.0)
        return boxes_tensor[valid], labels_tensor[valid]

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, target)``."""
        image_id = self.image_ids[idx]
        image_path = self._resolve_image_path(image_id)
        xml_path = self.annotations_root / f'{image_id}.xml'

        # Close the file handle as soon as the pixels have been converted.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        width, height = image.size

        boxes, class_names = parse_voc_annotation(xml_path)
        boxes_tensor, labels_tensor = self._encode_objects(boxes, class_names, width, height)

        image_tensor = TF.to_tensor(np.array(image))

        area = (boxes_tensor[:, 2] - boxes_tensor[:, 0]) * (boxes_tensor[:, 3] - boxes_tensor[:, 1])
        iscrowd = torch.zeros((boxes_tensor.shape[0],), dtype=torch.int64)
        target = {
            'boxes': boxes_tensor,
            'labels': labels_tensor,
            'image_id': torch.tensor([idx], dtype=torch.int64),
            'area': area,
            'iscrowd': iscrowd,
            'image_name': image_id,
            'size': torch.tensor([height, width], dtype=torch.int64),
        }
        return image_tensor, target
def detection_collate_fn(batch):
    """Collate ``(image, target)`` samples without stacking them.

    Args:
        batch: Non-empty sequence of ``(image, target)`` pairs.

    Returns:
        ``(images, targets)`` as two lists in sample order.
    """
    images, targets = zip(*batch)
    return list(images), list(targets)


# Annotation Sanity Check
# Visualize sample images with ground-truth bounding boxes and class labels.
def draw_boxes(ax, image_np, boxes, labels, class_map, color='lime', linewidth=2, alpha=1.0):
    """Show ``image_np`` on ``ax`` and outline each box with its class caption.

    Args:
        ax: Matplotlib axes to draw on; its axis decorations are turned off.
        image_np: Image passed to ``ax.imshow``.
        boxes: Iterable of ``(x1, y1, x2, y2)`` boxes.
        labels: Iterable of integer labels, paired with ``boxes``.
        class_map: Mapping of label ID to display name; IDs not in the map are
            shown as the number itself.
        color: Edge colour of the rectangles and background of the captions.
        linewidth: Rectangle line width.
        alpha: Rectangle opacity.
    """
    ax.imshow(image_np)
    ax.axis('off')
    for (left, top, right, bottom), label in zip(boxes, labels):
        outline = patches.Rectangle(
            (left, top),
            right - left,
            bottom - top,
            linewidth=linewidth,
            edgecolor=color,
            facecolor='none',
            alpha=alpha,
        )
        ax.add_patch(outline)

        class_id = int(label)
        caption = class_map.get(class_id, str(class_id))
        # Place the caption just above the box, but never above the image edge.
        caption_y = max(0, top - 3)
        ax.text(
            left,
            caption_y,
            caption,
            color='white',
            fontsize=9,
            bbox=dict(facecolor=color, alpha=0.7, pad=1.5, edgecolor='none'),
        )
def show_ground_truth_samples(dataset, max_items=6, cols=3):
    """Plot the first samples of ``dataset`` with their ground-truth boxes.

    Args:
        dataset: Indexable of ``(image, target)`` where ``image`` is a CHW
            tensor and ``target`` has ``boxes``, ``labels`` and ``image_name``.
        max_items: Upper bound on the number of samples drawn.
        cols: Number of grid columns.

    Returns:
        None. Prints a notice and returns without plotting when there is
        nothing to draw.
    """
    n = min(max_items, len(dataset))
    if n <= 0 or cols <= 0:
        # plt.subplots cannot build a grid with zero rows or columns.
        print('No samples to display.')
        return

    rows = int(np.ceil(n / cols))
    # squeeze=False always yields a 2-D array of axes, whatever the grid shape.
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i >= n:
            # Blank out grid cells beyond the last sample.
            ax.axis('off')
            continue
        image, target = dataset[i]
        img_np = image.permute(1, 2, 0).numpy()
        boxes = target['boxes'].numpy()
        labels = target['labels'].numpy()
        draw_boxes(ax, img_np, boxes, labels, id_to_label, color='deepskyblue')
        ax.set_title(f"{target['image_name']} ({len(labels)} objects)")
    plt.tight_layout()
    plt.show()
# Preview the first training samples with their ground-truth boxes.
_train_vis_dataset = BCCDDataset(
    split_ids['train'],
    images_root=images_root,
    annotations_root=annotations_root,
    label_to_id=label_to_id,
)
show_ground_truth_samples(_train_vis_dataset, max_items=6, cols=3)


# Create Detection Datasets
# Instantiate train/val/test datasets using predefined ImageSets splits.
def _make_split_dataset(split_name):
    """Build a BCCDDataset for one entry of ``split_ids`` using the shared roots."""
    return BCCDDataset(
        split_ids[split_name],
        images_root=images_root,
        annotations_root=annotations_root,
        label_to_id=label_to_id,
    )


train_dataset = _make_split_dataset('train')
# Separate dataset object over the same training split, used for evaluation.
train_eval_dataset = _make_split_dataset('train')
val_dataset = _make_split_dataset('val')
test_dataset = _make_split_dataset('test')

print('Datasets ready:')
print(' train:', len(train_dataset))
print(' train_eval:', len(train_eval_dataset))
print(' val:', len(val_dataset))
print(' test:', len(test_dataset))


# DataLoaders And Model
# Build detection dataloaders and instantiate RetinaNet.
loader_num_workers = hyperparameters.get('num_workers', 2)
# Persistent workers are only requested when worker processes are used at all.
use_persistent_workers = hyperparameters.get('persistent_workers', True) and loader_num_workers > 0
pin_memory = torch.cuda.is_available()


def _make_loader(dataset, shuffle):
    """Build a DataLoader with the shared batch, worker and collate settings."""
    return DataLoader(
        dataset,
        batch_size=hyperparameters['batch_size'],
        shuffle=shuffle,
        num_workers=loader_num_workers,
        pin_memory=pin_memory,
        persistent_workers=use_persistent_workers,
        collate_fn=detection_collate_fn,
    )


# Only the training loader shuffles; the evaluation loaders keep dataset order.
train_loader = _make_loader(train_dataset, shuffle=True)
train_eval_loader = _make_loader(train_eval_dataset, shuffle=False)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

model = build_model(hyperparameters, num_classes=num_classes).to(device)
optimizer, scheduler = build_optimizer_and_scheduler(model, hyperparameters)

total_params, trainable_params = count_trainable_parameters(model)
backbone_name = str(hyperparameters.get('backbone_name', 'resnet50')).lower()
print(f"Model: RetinaNet {backbone_name.replace('resnet', 'ResNet')}-FPN")
print(f"Trainable parameters: {trainable_params:,} / {total_params:,}")


# Batch Visualization
# Inspect one training batch with detection targets.
# Pull one batch from the training loader and draw up to four of its samples.
images, targets = next(iter(train_loader))
print('Batch size:', len(images))
print('Example image tensor shape:', tuple(images[0].shape))
print('Example targets keys:', list(targets[0].keys()))

n_show = min(4, len(images))
# squeeze=False keeps a 2-D axes array even for a single column; use its only row.
fig, axes = plt.subplots(1, n_show, figsize=(5 * n_show, 5), squeeze=False)
axes = axes[0]
for i in range(n_show):
    img_np = images[i].permute(1, 2, 0).numpy()
    boxes = targets[i]['boxes'].numpy()
    labels = targets[i]['labels'].numpy()
    draw_boxes(axes[i], img_np, boxes, labels, id_to_label, color='gold')
    axes[i].set_title(f"{targets[i]['image_name']} ({len(labels)} objects)")
plt.tight_layout()
plt.show()


# Training Step
# One training epoch for RetinaNet using detection losses.
def train_one_epoch(model, train_loader, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Detector that returns a dict of loss tensors when called as
            ``model(images, targets)`` in training mode.
        train_loader: Iterable of ``(images, targets)`` batches; every target
            provides ``boxes``, ``labels``, ``image_id``, ``area`` and
            ``iscrowd`` tensors.
        optimizer: Optimizer stepped once per batch.
        device: Device the images and target tensors are moved to.

    Returns:
        ``(avg_loss, avg_cls_loss, avg_reg_loss)`` averaged over the batches;
        all three are ``0.0`` for an empty loader.
    """
    model.train()
    running_loss = 0.0
    running_cls_loss = 0.0
    running_reg_loss = 0.0
    n_batches = 0
    tensor_keys = ('boxes', 'labels', 'image_id', 'area', 'iscrowd')

    for images, targets in train_loader:
        images = [img.to(device, non_blocking=True) for img in images]
        # Forward only the tensor entries; other target fields stay behind.
        targets_device = [
            {key: t[key].to(device, non_blocking=True) for key in tensor_keys}
            for t in targets
        ]

        loss_dict = model(images, targets_device)
        loss = sum(loss_dict.values())

        # Prefer the exact loss keys; otherwise fall back to matching key names.
        cls_loss = loss_dict.get('classification')
        if cls_loss is None:
            cls_candidates = [v for k, v in loss_dict.items() if ('class' in k.lower() or 'cls' in k.lower())]
            cls_loss = sum(cls_candidates) if cls_candidates else torch.tensor(0.0, device=loss.device)
        reg_loss = loss_dict.get('bbox_regression')
        if reg_loss is None:
            reg_candidates = [v for k, v in loss_dict.items() if ('reg' in k.lower() or 'bbox' in k.lower() or 'box' in k.lower())]
            reg_loss = sum(reg_candidates) if reg_candidates else torch.tensor(0.0, device=loss.device)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        running_loss += float(loss.item())
        running_cls_loss += float(cls_loss.item())
        running_reg_loss += float(reg_loss.item())
        n_batches += 1

    # max(1, ...) avoids a division by zero when the loader yielded nothing.
    avg_loss = running_loss / max(1, n_batches)
    avg_cls_loss = running_cls_loss / max(1, n_batches)
    avg_reg_loss = running_reg_loss / max(1, n_batches)
    return avg_loss, avg_cls_loss, avg_reg_loss


# mAP Evaluation
# Compute VOC-style AP per class and mAP at IoU=0.5.
def compute_ap_from_detections(scores, is_tp, total_gt):
    """Compute average precision as the area under the precision envelope.

    Args:
        scores: Confidence of every detection.
        is_tp: Flags (1 = true positive, 0 = false positive) aligned with ``scores``.
        total_gt: Number of ground-truth objects of the class.

    Returns:
        ``nan`` when ``total_gt`` is 0, ``0.0`` when there are no detections,
        otherwise the AP as a float.
    """
    if total_gt == 0:
        return np.nan
    if len(scores) == 0:
        return 0.0

    # Rank detections from most to least confident.
    ranking = np.argsort(-np.asarray(scores))
    hits = np.asarray(is_tp)[ranking].astype(np.float32)
    misses = 1.0 - hits

    cum_hits = np.cumsum(hits)
    cum_misses = np.cumsum(misses)
    recall = cum_hits / max(1.0, float(total_gt))
    precision = cum_hits / np.maximum(cum_hits + cum_misses, 1e-8)

    # Pad both curves with sentinels at the ends.
    padded_recall = np.concatenate(([0.0], recall, [1.0]))
    padded_precision = np.concatenate(([0.0], precision, [0.0]))

    # Precision envelope: each entry becomes the maximum of itself and all
    # entries to its right (running maximum taken from the end).
    envelope = np.maximum.accumulate(padded_precision[::-1])[::-1]

    # Sum rectangle areas at the positions where recall changes.
    steps = np.where(padded_recall[1:] != padded_recall[:-1])[0]
    area = np.sum((padded_recall[steps + 1] - padded_recall[steps]) * envelope[steps + 1])
    return float(area)
def _flag_true_positives(pred_boxes, gt_boxes, iou_threshold):
    """Greedily match score-sorted predictions of one class to its ground truth.

    A prediction is a true positive when its best-overlapping ground-truth box
    reaches ``iou_threshold`` and has not been claimed by an earlier prediction.
    Returns a list of 1.0/0.0 flags, one per prediction.
    """
    if gt_boxes.shape[0] == 0:
        return [0.0] * int(pred_boxes.shape[0])

    best_ious, best_idxs = box_iou(pred_boxes, gt_boxes).max(dim=1)
    overlaps_enough = (best_ious >= iou_threshold).tolist()
    matched_gt = [False] * int(gt_boxes.shape[0])
    flags = []
    for enough, gt_idx in zip(overlaps_enough, best_idxs.tolist()):
        if enough and not matched_gt[gt_idx]:
            matched_gt[gt_idx] = True
            flags.append(1.0)
        else:
            flags.append(0.0)
    return flags


@torch.no_grad()
def evaluate_map(model, data_loader, class_ids, iou_threshold=0.5, score_threshold=0.05):
    """Compute per-class AP and their mean over ``data_loader``.

    The model is switched to eval mode and is not switched back.

    Args:
        model: Detector returning, per image, a dict with ``boxes``, ``scores``
            and ``labels`` when called with a list of images.
        data_loader: Iterable of ``(images, targets)`` batches; targets provide
            ``boxes`` and ``labels`` tensors.
        class_ids: Label IDs to evaluate.
        iou_threshold: Minimum IoU for a prediction to match a ground-truth box.
        score_threshold: Predictions scoring below this are discarded.

    Returns:
        ``(map_value, ap_per_class)``. ``ap_per_class`` maps each class ID to
        its AP (``nan`` for classes without ground truth); ``map_value`` is the
        mean over the non-NaN entries, or ``nan`` if there are none.
    """
    model.eval()
    # Use the device the model lives on instead of depending on a global.
    try:
        model_device = next(model.parameters()).device
    except StopIteration:
        model_device = torch.device('cpu')

    per_class_scores = {c: [] for c in class_ids}
    per_class_tp = {c: [] for c in class_ids}
    per_class_total_gt = {c: 0 for c in class_ids}

    for images, targets in data_loader:
        images_device = [img.to(model_device, non_blocking=True) for img in images]
        outputs = model(images_device)

        for output, target in zip(outputs, targets):
            gt_boxes = target['boxes']
            gt_labels = target['labels']
            pred_boxes = output['boxes'].detach().cpu()
            pred_scores = output['scores'].detach().cpu()
            pred_labels = output['labels'].detach().cpu()

            keep = pred_scores >= score_threshold
            pred_boxes = pred_boxes[keep]
            pred_scores = pred_scores[keep]
            pred_labels = pred_labels[keep]

            for cls_id in class_ids:
                cls_gt_boxes = gt_boxes[gt_labels == cls_id]
                per_class_total_gt[cls_id] += int(cls_gt_boxes.shape[0])

                pred_mask = pred_labels == cls_id
                cls_pred_scores = pred_scores[pred_mask]
                if cls_pred_scores.numel() == 0:
                    continue

                # Most confident predictions claim ground-truth boxes first.
                order = torch.argsort(cls_pred_scores, descending=True)
                cls_pred_boxes = pred_boxes[pred_mask][order]
                cls_pred_scores = cls_pred_scores[order]

                per_class_scores[cls_id].extend(cls_pred_scores.tolist())
                per_class_tp[cls_id].extend(
                    _flag_true_positives(cls_pred_boxes, cls_gt_boxes, iou_threshold)
                )

    ap_per_class = {
        cls_id: compute_ap_from_detections(
            per_class_scores[cls_id],
            per_class_tp[cls_id],
            per_class_total_gt[cls_id],
        )
        for cls_id in class_ids
    }
    valid_aps = [v for v in ap_per_class.values() if not np.isnan(v)]
    map_value = float(np.mean(valid_aps)) if len(valid_aps) > 0 else float('nan')
    return map_value, ap_per_class


# Train And Test
# Train RetinaNet, monitor mAP on train/val every epoch, and report final test mAP.
# Per-epoch history, used for the plots at the end of this section.
train_losses = []
train_maps = []
val_maps = []

# Best checkpoint so far, selected by validation mAP.
best_val_map = -1.0
best_epoch = 0
best_model_state = copy.deepcopy(model.state_dict())
class_ids = sorted(id_to_label.keys())

wandb_run = None
if hyperparameters.get('use_wandb', True):
    wandb_run = wandb.init(
        project=hyperparameters.get('wandb_project', 'deep-learning-2026-retinanet'),
        entity=hyperparameters.get('wandb_entity', None),
        name=hyperparameters.get('wandb_run_name', None),
        config={
            **hyperparameters,
            'num_classes': num_classes,
            'class_names': class_names,
            'device': str(device),
            'dataset_root': str(dataset_root),
            'train_size': len(train_dataset),
            'val_size': len(val_dataset),
            'test_size': len(test_dataset),
        },
    )
    if hyperparameters.get('wandb_watch_model', True):
        wandb.watch(model, log='all', log_freq=100)

for epoch in range(hyperparameters['num_epochs']):
    start = time.time()

    # Read the learning rate before the scheduler advances, so the value that
    # is printed and logged is the one this epoch actually trained with.
    lr = optimizer.param_groups[0]['lr']
    train_loss, train_cls_loss, train_reg_loss = train_one_epoch(model, train_loader, optimizer, device)
    scheduler.step()

    train_map, train_ap = evaluate_map(
        model,
        train_eval_loader,
        class_ids=class_ids,
        iou_threshold=hyperparameters['map_iou_threshold'],
        score_threshold=hyperparameters['score_threshold_eval'],
    )
    val_map, val_ap = evaluate_map(
        model,
        val_loader,
        class_ids=class_ids,
        iou_threshold=hyperparameters['map_iou_threshold'],
        score_threshold=hyperparameters['score_threshold_eval'],
    )

    train_losses.append(train_loss)
    train_maps.append(train_map)
    val_maps.append(val_map)

    elapsed = time.time() - start
    print(
        f"Epoch {epoch + 1:02d}/{hyperparameters['num_epochs']} | "
        f"{elapsed:.1f}s | lr={lr:.2e} | loss={train_loss:.4f} (cls={train_cls_loss:.4f}, reg={train_reg_loss:.4f}) | "
        f"train mAP@{hyperparameters['map_iou_threshold']:.2f}={train_map:.4f} | "
        f"val mAP@{hyperparameters['map_iou_threshold']:.2f}={val_map:.4f}"
    )

    # Snapshot the weights whenever validation mAP improves.
    if val_map > best_val_map:
        best_val_map = val_map
        best_epoch = epoch + 1
        best_model_state = copy.deepcopy(model.state_dict())

    if wandb_run is not None:
        epoch_log = {
            'epoch': epoch + 1,
            'train/loss': float(train_loss),
            'train/loss_classification': float(train_cls_loss),
            'train/loss_regression': float(train_reg_loss),
            'train/map': float(train_map),
            'val/map': float(val_map),
            'train/val_gap_map': float(train_map - val_map),
            'train/lr': float(lr),
            'train/epoch_time_sec': float(elapsed),
            'best/val_map_so_far': float(best_val_map),
            'best/epoch_so_far': int(best_epoch),
        }
        for cls_id in class_ids:
            cls_name = id_to_label[cls_id]
            epoch_log[f'train_ap/{cls_name}'] = float(train_ap[cls_id])
            epoch_log[f'val_ap/{cls_name}'] = float(val_ap[cls_id])
        wandb.log(epoch_log, step=epoch + 1)

# Evaluate the best checkpoint (by validation mAP) on the test split.
model.load_state_dict(best_model_state)
test_map, test_ap = evaluate_map(
    model,
    test_loader,
    class_ids=class_ids,
    iou_threshold=hyperparameters['map_iou_threshold'],
    score_threshold=hyperparameters['score_threshold_eval'],
)

print()
print(f"Best validation mAP@{hyperparameters['map_iou_threshold']:.2f}: {best_val_map:.4f} at epoch {best_epoch}")
print(f"Test mAP@{hyperparameters['map_iou_threshold']:.2f}: {test_map:.4f}")
print('Test AP per class:')
for cls_id in class_ids:
    print(f" {id_to_label[cls_id]}: {test_ap[cls_id]:.4f}")

if wandb_run is not None:
    final_log = {
        'best/epoch': int(best_epoch),
        'best/val_map': float(best_val_map),
        'test/map': float(test_map),
    }
    for cls_id in class_ids:
        cls_name = id_to_label[cls_id]
        final_log[f'test_ap/{cls_name}'] = float(test_ap[cls_id])
    wandb.log(final_log, step=hyperparameters['num_epochs'])

    if hyperparameters.get('wandb_log_model', False):
        model_artifact = wandb.Artifact(
            name=f"{wandb.run.name}-weights",
            type='model',
            description='Best RetinaNet weights (by validation mAP)',
        )
        weights_path = Path('retinanet_best_model.pth')
        torch.save(best_model_state, weights_path)
        model_artifact.add_file(str(weights_path))
        wandb.log_artifact(model_artifact)

# Training curves: loss on the left, train/validation mAP on the right.
epochs = np.arange(1, len(train_losses) + 1)
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(epochs, train_losses, marker='o', label='Train loss')
plt.title('RetinaNet training loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(alpha=0.25)
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epochs, train_maps, marker='o', label='Train mAP')
plt.plot(epochs, val_maps, marker='o', label='Validation mAP')
plt.title(f"mAP@{hyperparameters['map_iou_threshold']:.2f} per epoch")
plt.xlabel('Epoch')
plt.ylabel('mAP')
plt.ylim(0, 1.0)
plt.grid(alpha=0.25)
plt.legend()

plt.tight_layout()
plt.show()

if wandb_run is not None:
    wandb.finish()


# Qualitative Results
# Object-detection style visualization of predictions vs ground truth.
def draw_prediction_boxes(ax, image_np, boxes, labels, scores, class_map, score_thr=0.25):
    """Show ``image_np`` on ``ax`` and draw the predictions scoring at least ``score_thr``.

    Args:
        ax: Matplotlib axes to draw on; its axis decorations are turned off.
        image_np: Image passed to ``ax.imshow``.
        boxes: Iterable of ``(x1, y1, x2, y2)`` predicted boxes.
        labels: Iterable of integer labels, paired with ``boxes``.
        scores: Iterable of confidences, paired with ``boxes``.
        class_map: Mapping of label ID to display name; IDs not in the map are
            shown as the number itself.
        score_thr: Predictions with a lower score are not drawn.
    """
    ax.imshow(image_np)
    ax.axis('off')
    confident = (
        (box, label, score)
        for box, label, score in zip(boxes, labels, scores)
        if not score < score_thr
    )
    for (left, top, right, bottom), label, score in confident:
        outline = patches.Rectangle(
            (left, top),
            right - left,
            bottom - top,
            linewidth=2,
            edgecolor='crimson',
            facecolor='none',
        )
        ax.add_patch(outline)

        class_id = int(label)
        cls_name = class_map.get(class_id, str(class_id))
        # Caption sits just above the box, but never above the image edge.
        ax.text(
            left,
            max(0, top - 3),
            f"{cls_name} {score:.2f}",
            color='white',
            fontsize=9,
            bbox=dict(facecolor='crimson', alpha=0.75, edgecolor='none', pad=1.5),
        )
@torch.no_grad()
def show_detection_results(model, data_loader, max_items=4, score_thr=0.25):
    """Plot ground truth (left) next to predictions (right) for the first images.

    The model is switched to eval mode and is not switched back.

    Args:
        model: Detector returning, per image, a dict with ``boxes``, ``labels``
            and ``scores`` when called with a list of images.
        data_loader: Iterable of ``(images, targets)`` batches.
        max_items: Number of image rows to draw; nothing is drawn when < 1.
        score_thr: Minimum score for a prediction to be drawn.

    Returns:
        None.
    """
    if max_items < 1:
        # plt.subplots cannot build a grid with zero rows.
        return

    model.eval()
    # Use the device the model lives on instead of depending on a global.
    try:
        model_device = next(model.parameters()).device
    except StopIteration:
        model_device = torch.device('cpu')

    # squeeze=False always yields a 2-D axes array, also for a single row.
    fig, axes = plt.subplots(max_items, 2, figsize=(12, 5 * max_items), squeeze=False)
    shown = 0
    for images, targets in data_loader:
        if shown >= max_items:
            break
        outputs = model([img.to(model_device) for img in images])
        for image, target, output in zip(images, targets, outputs):
            if shown >= max_items:
                break
            image_np = image.permute(1, 2, 0).numpy()
            gt_ax = axes[shown, 0]
            pred_ax = axes[shown, 1]

            draw_boxes(
                gt_ax,
                image_np,
                target['boxes'].numpy(),
                target['labels'].numpy(),
                id_to_label,
                color='dodgerblue',
            )
            gt_ax.set_title(f"GT: {target['image_name']}")

            draw_prediction_boxes(
                pred_ax,
                image_np,
                output['boxes'].detach().cpu().numpy(),
                output['labels'].detach().cpu().numpy(),
                output['scores'].detach().cpu().numpy(),
                id_to_label,
                score_thr=score_thr,
            )
            pred_ax.set_title(f"Prediction (score >= {score_thr:.2f})")
            shown += 1

    # Hide the rows left over when the loader held fewer images than requested.
    for unused_row in axes[shown:]:
        for unused_ax in unused_row:
            unused_ax.axis('off')

    plt.tight_layout()
    plt.show()
# Qualitative panel at the configured visualization threshold.
_vis_threshold = hyperparameters['score_threshold_vis']
show_detection_results(model, test_loader, max_items=4, score_thr=_vis_threshold)

# Visual comparison at different confidence thresholds
for th in [0.15, 0.35, 0.55]:
    print(f'Visualization threshold: {th:.2f}')
    show_detection_results(model, test_loader, max_items=2, score_thr=th)

# Compact AP summary on train/val/test with the best model loaded
_eval_settings = {
    'class_ids': class_ids,
    'iou_threshold': hyperparameters['map_iou_threshold'],
    'score_threshold': hyperparameters['score_threshold_eval'],
}
train_map_best, train_ap_best = evaluate_map(model, train_eval_loader, **_eval_settings)
val_map_best, val_ap_best = evaluate_map(model, val_loader, **_eval_settings)

# The test metrics computed after training are reused, not recomputed.
_summary_rows = (
    ('Train', train_map_best, train_ap_best),
    ('Val', val_map_best, val_ap_best),
    ('Test', test_map, test_ap),
)
for _split_label, _split_map, _split_ap in _summary_rows:
    print(f"{_split_label} mAP@{hyperparameters['map_iou_threshold']:.2f}: {_split_map:.4f}")
    for cls_id in class_ids:
        print(f" {_split_label} AP {id_to_label[cls_id]}: {_split_ap[cls_id]:.4f}")

# Final qualitative panel
show_detection_results(model, test_loader, max_items=6, score_thr=_vis_threshold)