"""This module contains utility classes and functions."""
import enum
import os
import typing as t
import warnings
import numpy as np
import pymia.data.conversion as conversion
import pymia.filtering.filter as fltr
import pymia.evaluation.evaluator as eval_
import pymia.evaluation.metric as metric
import SimpleITK as sitk
import mialab.data.structure as structure
import mialab.filtering.feature_extraction as fltr_feat
import mialab.filtering.postprocessing as fltr_postp
import mialab.filtering.preprocessing as fltr_prep
import mialab.utilities.multi_processor as mproc
# Module-level T1w / T2w atlas images. They start out as empty SimpleITK images and are
# rebound by load_atlas_images(); the pipeline functions read them at call time.
atlas_t1, atlas_t2 = sitk.Image(), sitk.Image()
def load_atlas_images(directory: str):
    """Loads the T1 and T2 atlas images into the module-level ``atlas_t1`` and ``atlas_t2``.

    Both images are read and validated before either global is rebound, so a failed load
    (an unreadable file or mismatching image properties) leaves the previously loaded atlas
    pair untouched instead of a half-updated, inconsistent pair.

    Args:
        directory (str): The atlas data directory.

    Raises:
        ValueError: If the T1w and T2w atlas images do not share the same image properties.
    """
    global atlas_t1
    global atlas_t2

    # NOTE(review): the T1 atlas file carries a '_mask' suffix while the T2 atlas does not;
    # presumably this is a masked T1 template - confirm it is the intended file.
    t1_image = sitk.ReadImage(os.path.join(directory, 'mni_icbm152_t1_tal_nlin_sym_09a_mask.nii.gz'))
    t2_image = sitk.ReadImage(os.path.join(directory, 'mni_icbm152_t2_tal_nlin_sym_09a.nii.gz'))
    if not conversion.ImageProperties(t1_image) == conversion.ImageProperties(t2_image):
        raise ValueError('T1w and T2w atlas images have not the same image properties')

    # only publish the pair once both images were loaded and validated
    atlas_t1 = t1_image
    atlas_t2 = t2_image
class FeatureImageTypes(enum.Enum):
    """Represents the feature image types.

    Each member carries a fixed integer value (1 to 5, in declaration order).
    """

    ATLAS_COORD = 1
    T1w_INTENSITY = 2
    T1w_GRADIENT_INTENSITY = 3
    T2w_INTENSITY = 4
    T2w_GRADIENT_INTENSITY = 5
def pre_process(id_: str, paths: dict, **kwargs) -> structure.BrainImage:
    """Loads and processes an image.

    The processing includes:

    - Registration
    - Pre-processing
    - Feature extraction

    Args:
        id_ (str): An image identifier.
        paths (dict): A dict, where the keys are an image identifier of type structure.BrainImageTypes
            and the values are paths to the images. The value stored under ``id_`` is the root directory
            of the image, and the value under ``structure.BrainImageTypes.RegistrationTransform`` is the
            path to the registration transform. The dict is not modified.
        **kwargs: Processing flags ``registration_pre``, ``skullstrip_pre`` and ``normalization_pre``
            (each defaults to False). All keyword arguments are also forwarded to ``FeatureExtractor``.

    Returns:
        (structure.BrainImage): The pre-processed image with its features extracted. Its
        ``feature_images`` are emptied to free memory; only the feature matrix is kept.
    """
    print('-' * 10, 'Processing', id_)

    # Work on a copy: the two pops below must not mutate the caller's dict. In sequential mode
    # that dict belongs to the caller's data batch, and a second run would otherwise fail.
    paths = dict(paths)
    root_dir = paths.pop(id_, '')  # the value with key id_ is the root directory of the image
    path_to_transform = paths.pop(structure.BrainImageTypes.RegistrationTransform, '')
    images = {img_key: sitk.ReadImage(img_path) for img_key, img_path in paths.items()}
    transform = sitk.ReadTransform(path_to_transform)
    img = structure.BrainImage(id_, root_dir, images, transform)

    do_registration = kwargs.get('registration_pre', False)
    do_skullstrip = kwargs.get('skullstrip_pre', False)
    do_normalization = kwargs.get('normalization_pre', False)

    # The brain mask is registered before the T1w and T2w pipelines because the registered
    # mask is used for skull-stripping.
    mask_key = structure.BrainImageTypes.BrainMask
    img.images[mask_key] = _register_to_atlas_t1(img.images[mask_key], img.transformation, do_registration)
    brain_mask = img.images[mask_key]

    t1_key = structure.BrainImageTypes.T1w
    img.images[t1_key] = _pre_process_intensity_image(img.images[t1_key], atlas_t1, img.transformation,
                                                      brain_mask, do_registration, do_skullstrip,
                                                      do_normalization)

    t2_key = structure.BrainImageTypes.T2w
    img.images[t2_key] = _pre_process_intensity_image(img.images[t2_key], atlas_t2, img.transformation,
                                                      brain_mask, do_registration, do_skullstrip,
                                                      do_normalization)

    gt_key = structure.BrainImageTypes.GroundTruth
    img.images[gt_key] = _register_to_atlas_t1(img.images[gt_key], img.transformation, do_registration)

    # update image properties to atlas image properties after registration
    img.image_properties = conversion.ImageProperties(img.images[t1_key])

    # extract the features
    feature_extractor = FeatureExtractor(img, **kwargs)
    img = feature_extractor.execute()

    img.feature_images = {}  # free memory: only img.feature_matrix is needed to train the classifier
    return img


def _register_to_atlas_t1(image: sitk.Image, transformation, do_registration) -> sitk.Image:
    """Runs the optional atlas registration shared by the brain-mask and ground-truth images."""
    pipeline = fltr.FilterPipeline()
    if do_registration:
        pipeline.add_filter(fltr_prep.ImageRegistration())
        # NOTE(review): the trailing True is passed exactly as the original mask/ground-truth pipelines
        # did; presumably it marks a label image - confirm against fltr_prep.ImageRegistrationParameters.
        pipeline.set_param(fltr_prep.ImageRegistrationParameters(atlas_t1, transformation, True),
                           len(pipeline.filters) - 1)
    return pipeline.execute(image)


def _pre_process_intensity_image(image: sitk.Image, atlas: sitk.Image, transformation, brain_mask: sitk.Image,
                                 do_registration, do_skullstrip, do_normalization) -> sitk.Image:
    """Runs the optional registration, skull-stripping and normalization shared by T1w and T2w."""
    pipeline = fltr.FilterPipeline()
    if do_registration:
        pipeline.add_filter(fltr_prep.ImageRegistration())
        pipeline.set_param(fltr_prep.ImageRegistrationParameters(atlas, transformation),
                           len(pipeline.filters) - 1)
    if do_skullstrip:
        pipeline.add_filter(fltr_prep.SkullStripping())
        pipeline.set_param(fltr_prep.SkullStrippingParameters(brain_mask), len(pipeline.filters) - 1)
    if do_normalization:
        pipeline.add_filter(fltr_prep.ImageNormalization())
    return pipeline.execute(image)
def post_process(img: structure.BrainImage, segmentation: sitk.Image, probability: sitk.Image,
                 **kwargs) -> sitk.Image:
    """Post-processes a segmentation.

    Filters are appended in a fixed order:

    - ``fltr_postp.ImagePostProcessing`` when ``simple_post`` is truthy.
    - ``fltr_postp.DenseCRF`` when ``crf_post`` is truthy. It is parameterized with the T1w and
      T2w images of ``img`` and with ``probability``.

    With neither flag set, the segmentation runs through an empty pipeline.

    Args:
        img (structure.BrainImage): The image.
        segmentation (sitk.Image): The segmentation (label image).
        probability (sitk.Image): The probabilities images (a vector image).

    Returns:
        sitk.Image: The post-processed image.
    """
    print('-' * 10, 'Post-processing', img.id_)

    use_simple = kwargs.get('simple_post', False)
    use_crf = kwargs.get('crf_post', False)

    pipeline = fltr.FilterPipeline()
    if use_simple:
        pipeline.add_filter(fltr_postp.ImagePostProcessing())
    if use_crf:
        pipeline.add_filter(fltr_postp.DenseCRF())
        crf_index = len(pipeline.filters) - 1
        modalities = img.images
        crf_params = fltr_postp.DenseCRFParams(modalities[structure.BrainImageTypes.T1w],
                                               modalities[structure.BrainImageTypes.T2w],
                                               probability)
        pipeline.set_param(crf_params, crf_index)

    return pipeline.execute(segmentation)
def init_evaluator() -> eval_.Evaluator:
    """Initializes an evaluator.

    The evaluator uses the Dice coefficient as its only metric and labels 1-5 as WhiteMatter,
    GreyMatter, Hippocampus, Amygdala and Thalamus. A warning is emitted as a reminder that
    further metrics could be added.

    Returns:
        eval_.Evaluator: An evaluator.
    """
    # TODO: add hausdorff distance, 95th percentile (see metric.HausdorffDistance)
    dice = metric.DiceCoefficient()
    warnings.warn('Initialized evaluation with the Dice coefficient. Do you know other suitable metrics?')

    # label value -> structure name, numbered from 1 in this order
    label_names = dict(enumerate(('WhiteMatter', 'GreyMatter', 'Hippocampus', 'Amygdala', 'Thalamus'), start=1))
    return eval_.SegmentationEvaluator([dice], label_names)
def pre_process_batch(data_batch: t.Dict[str, dict],
                      pre_process_params: dict = None, multi_process: bool = True) -> t.List[structure.BrainImage]:
    """Loads and pre-processes a batch of images.

    The pre-processing includes:

    - Registration
    - Pre-processing
    - Feature extraction

    Args:
        data_batch (Dict[str, dict]): Batch of images to be processed. Each key is an image identifier and
            each value is the ``paths`` dict that ``pre_process`` expects for that identifier. The batch is
            not modified.
        pre_process_params (dict): Pre-processing parameters, forwarded as keyword arguments to ``pre_process``.
        multi_process (bool): Whether to use the parallel processing on multiple cores or to run sequentially.

    Returns:
        List[structure.BrainImage]: A list of images.
    """
    if pre_process_params is None:
        pre_process_params = {}

    # pre_process may pop entries from the paths dict it receives; pass copies so the caller's data_batch
    # is never modified (this matters in sequential mode, where no pickling copy happens).
    params_list = [(id_, dict(paths)) for id_, paths in data_batch.items()]
    if multi_process:
        images = mproc.MultiProcessor.run(pre_process, params_list, pre_process_params, mproc.PreProcessingPickleHelper)
    else:
        images = [pre_process(id_, paths, **pre_process_params) for id_, paths in params_list]
    return images
def post_process_batch(brain_images: t.List[structure.BrainImage], segmentations: t.List[sitk.Image],
                       probabilities: t.List[sitk.Image], post_process_params: dict = None,
                       multi_process: bool = True) -> t.List[sitk.Image]:
    """Post-processes a batch of images.

    Args:
        brain_images (List[structure.BrainImage]): Original images that were used for the prediction.
        segmentations (List[sitk.Image]): The predicted segmentation.
        probabilities (List[sitk.Image]): The prediction probabilities.
        post_process_params (dict): Post-processing parameters, forwarded as keyword arguments to
            ``post_process``.
        multi_process (bool): Whether to use the parallel processing on multiple cores or to run sequentially.

    Returns:
        List[sitk.Image]: List of post-processed images

    Raises:
        ValueError: If brain_images, segmentations and probabilities differ in length.
    """
    if post_process_params is None:
        post_process_params = {}

    brain_images = list(brain_images)
    segmentations = list(segmentations)
    probabilities = list(probabilities)
    # zip() would silently drop trailing items on a length mismatch; fail loudly instead
    if not len(brain_images) == len(segmentations) == len(probabilities):
        raise ValueError('brain_images, segmentations and probabilities must have the same length, got '
                         f'{len(brain_images)}, {len(segmentations)} and {len(probabilities)}')

    # materialize as a list (like pre_process_batch) rather than handing out a one-shot zip iterator
    param_list = list(zip(brain_images, segmentations, probabilities))
    if multi_process:
        pp_images = mproc.MultiProcessor.run(post_process, param_list, post_process_params,
                                             mproc.PostProcessingPickleHelper)
    else:
        pp_images = [post_process(img, seg, prob, **post_process_params) for img, seg, prob in param_list]
    return pp_images