mednet.config.detect.models.faster_rcnn

Faster R-CNN object detection (and classification) network architecture, from [RHGS17].

# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Faster R-CNN object detection (and classification) network architecture, from :cite:p:`ren_faster_2017`."""

import torch.optim
import torchvision.transforms
import torchvision.transforms.v2

import mednet.models.detect.faster_rcnn
import mednet.models.losses
import mednet.models.transforms

# Faster R-CNN detector configuration: SGD with a step-decay learning-rate
# schedule, and a fixed chain of input transforms applied before the network.
model = mednet.models.detect.faster_rcnn.FasterRCNN(
    # Optimisation: SGD with momentum and weight decay.
    optimizer_type=torch.optim.SGD,
    optimizer_arguments={
        "lr": 0.005,
        "momentum": 0.9,
        "weight_decay": 0.0005,
    },
    # StepLR multiplies the learning rate by ``gamma`` once every
    # ``step_size`` scheduler steps.
    scheduler_type=torch.optim.lr_scheduler.StepLR,
    scheduler_arguments={"step_size": 3, "gamma": 0.1},
    # Input pre-processing, applied in list order.
    model_transforms=[
        # presumably centre-pads each image to a square — confirm in
        # mednet.models.transforms
        mednet.models.transforms.SquareCenterPad(),
        # an int size resizes the smaller edge to 512 pixels
        torchvision.transforms.v2.Resize(512, antialias=True),
        torchvision.transforms.v2.RGB(),
    ],
    pretrained=True,
    # NOTE(review): confirm whether the wrapper counts the background class
    # in ``num_classes`` (torchvision's detection models do).
    num_classes=1,
    # lightweight backbone variant, chosen for fast testing
    variant="mobilenetv3-small",
)