Source code for zea.models.lv_segmentation

"""
nnU-Net segmentation model trained on the augmented CAMUS dataset.

To try this model, simply load one of the available presets:

.. doctest::

    >>> from zea.models.lv_segmentation import AugmentedCamusSeg

    >>> model = AugmentedCamusSeg.from_preset("augmented_camus_seg")

The model segments both the left ventricle and myocardium.

At the time of writing (17 September 2025) and to the best of our knowledge,
it is the state-of-the-art model for left ventricle segmentation on the CAMUS dataset.

.. important::
    This is a ``zea`` implementation of the model.
    For the original paper and code, see `here <https://github.com/GillesVanDeVyver/EchoGAINS>`_.

    Van De Vyver, Gilles, et al.
    "Generative augmentations for improved cardiac ultrasound segmentation using diffusion models."
    *https://arxiv.org/abs/2502.20100*

.. seealso::
    A tutorial notebook where this model is used:
    :doc:`../notebooks/models/left_ventricle_segmentation_example`.

.. note::
    The model is originally a PyTorch model converted to ONNX. To use this model, you must have ``onnxruntime`` installed, since it is required for ONNX model inference.

    You can install it using pip:

    .. code-block:: bash

        pip install onnxruntime

"""  # noqa: E501

from keras import ops

from zea.internal.registry import model_registry
from zea.models.base import BaseModel
from zea.models.preset_utils import get_preset_loader, register_presets
from zea.models.presets import augmented_camus_seg_presets

# Side length, in pixels, of the square images the model is documented to take
# (input shape [batch, 1, 256, 256]).
# NOTE(review): not referenced by the code visible in this module — confirm where it is used.
INFERENCE_SIZE = 256


@model_registry(name="augmented_camus_seg")
class AugmentedCamusSeg(BaseModel):
    """
    nnU-Net based left ventricle and myocardium segmentation model.

    - Trained on the augmented CAMUS dataset.
    - This class loads an ONNX model and provides inference for cardiac
      ultrasound segmentation tasks.

    The ONNX Runtime session is stored on the instance as ``onnx_sess`` by
    :meth:`custom_load_weights`; :meth:`call` refuses to run until that
    attribute exists.
    """

    def call(self, inputs):
        """
        Run inference on the input data using the loaded ONNX model.

        Args:
            inputs: Input image or batch of images for segmentation. Anything
                accepted by ``keras.ops.convert_to_numpy`` (e.g. ``np.ndarray``
                or a backend tensor); it is cast to ``float32`` before being
                handed to ONNX Runtime.
                Shape: [batch, 1, 256, 256]
                Range: documented as "any numeric range; normalized
                internally". No normalization is performed in this method, so
                it would have to be part of the ONNX graph (TODO confirm).

        Returns:
            np.ndarray: Segmentation output for left ventricle and myocardium,
            i.e. the first output of the ONNX session.
            Shape: [batch, 3, 256, 256] (logits for background, LV, myocardium)

        Raises:
            ValueError: If model weights are not loaded.
        """
        # The session only exists after custom_load_weights() has run.
        if not hasattr(self, "onnx_sess"):
            raise ValueError("Model weights not loaded. Please call custom_load_weights() first.")

        # Only the first declared input and output of the graph are used.
        input_name = self.onnx_sess.get_inputs()[0].name
        output_name = self.onnx_sess.get_outputs()[0].name

        # ONNX Runtime is fed a NumPy array, so convert whatever backend
        # tensor arrives here and cast it to float32.
        inputs = ops.convert_to_numpy(inputs).astype("float32")

        # run() returns one array per requested output; exactly one is requested.
        output = self.onnx_sess.run([output_name], {input_name: inputs})[0]
        return output

    def custom_load_weights(self, preset, **kwargs):
        """
        Load the ONNX weights for the segmentation model.

        Fetches ``model.onnx`` through the preset loader for ``preset`` and
        opens it as an ``onnxruntime.InferenceSession`` stored in
        ``self.onnx_sess``.

        Args:
            preset: Preset identifier passed to ``get_preset_loader``.
            **kwargs: Accepted for interface compatibility; not used here.

        Raises:
            ImportError: If ``onnxruntime`` is not installed. The original
                import failure is attached as the exception's cause.
        """
        # Imported lazily so that onnxruntime is only required when this model
        # is actually used.
        try:
            import onnxruntime
        except ImportError as exc:
            # Chain explicitly so the traceback shows the underlying import
            # failure as the direct cause of this friendlier message.
            raise ImportError(
                "onnxruntime is not installed. Please run "
                "`pip install onnxruntime` to use this model."
            ) from exc

        loader = get_preset_loader(preset)
        filename = loader.get_file("model.onnx")
        self.onnx_sess = onnxruntime.InferenceSession(filename)
# Register the augmented CAMUS presets for AugmentedCamusSeg at import time.
# NOTE(review): presumably this is what lets `AugmentedCamusSeg.from_preset(...)`
# resolve preset names — confirm against zea.models.preset_utils.
register_presets(augmented_camus_seg_presets, AugmentedCamusSeg)