mednet.engine.classify.predictor
Prediction engine for classification tasks.
Functions
run(model, datamodule, device_manager): Run inference on input data, output predictions.
- mednet.engine.classify.predictor.run(model, datamodule, device_manager)
Run inference on input data, output predictions.
- Parameters:
  - model (LightningModule) – Neural network model (e.g. pasa).
  - datamodule (LightningDataModule) – The lightning DataModule to run predictions on.
  - device_manager (DeviceManager) – An internal device representation, to be used for training and validation. This representation can be converted into a pytorch device or a lightning accelerator setup.
- Return type:
  Union[list[tuple[str, Sequence[int], Sequence[float]]], list[list[tuple[str, Sequence[int], Sequence[float]]]], Mapping[str, Sequence[tuple[str, Sequence[int], Sequence[float]]]], None]
- Returns:
  Depending on the return type of the DataModule’s predict_dataloader() method:
  - if torch.utils.data.DataLoader, then returns a list of predictions.
  - if list of torch.utils.data.DataLoader, then returns a list of lists of predictions, each list corresponding to the iteration over one of the dataloaders.
  - if dict of str to torch.utils.data.DataLoader, then returns a dictionary mapping names to lists of predictions.
  - if None, then returns None.
- Raises:
  TypeError – If the DataModule’s predict_dataloader() method does not return any of the types described above.
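
Because run() can return four different shapes, calling code usually has to branch on the result before using it. The following is a minimal sketch of one way to do that, not part of mednet itself. The helper name iter_predictions, the group labels ("predict", "0", "1", ...) and the sample values are all hypothetical. It also assumes that each prediction is a plain tuple, as the return type above suggests, and reading its three fields as sample name, integer targets and float scores is a guess based on the annotation alone.

```python
from collections.abc import Iterator, Mapping, Sequence

# One prediction, following the documented return type.
# Assumption: the fields are (sample name, integer targets, float scores).
Prediction = tuple[str, Sequence[int], Sequence[float]]


def iter_predictions(result) -> Iterator[tuple[str, Prediction]]:
    """Yield (group, prediction) pairs from any shape run() may return.

    The group labels are invented for this sketch: "predict" for a single
    dataloader, the list index for a list of dataloaders, and the mapping
    key for a dict of dataloaders.
    """
    if result is None:
        # predict_dataloader() returned None: nothing to iterate over.
        return
    if isinstance(result, Mapping):
        # dict of dataloaders: one list of predictions per name.
        for name, predictions in result.items():
            for prediction in predictions:
                yield name, prediction
        return
    for index, item in enumerate(result):
        if isinstance(item, tuple):
            # single dataloader: a flat list of predictions.
            yield "predict", item
        else:
            # list of dataloaders: one inner list per dataloader.
            for prediction in item:
                yield str(index), prediction


# Hypothetical values standing in for real outputs of run().
flat = [("a.png", [1], [0.9]), ("b.png", [0], [0.2])]
nested = [[("a.png", [1], [0.9])], [("b.png", [0], [0.2])]]
named = {"train": [("a.png", [1], [0.9])], "test": [("b.png", [0], [0.2])]}

for shape in (flat, nested, named, None):
    print(list(iter_predictions(shape)))
```

Flattening everything into (group, prediction) pairs keeps downstream code, such as writing predictions to a file, independent of how the DataModule organises its dataloaders.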