{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Left ventricle segmentation\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tue-bmd/zea/blob/main/docs/source/notebooks/models/left_ventricle_segmentation_example.ipynb)   [![View on GitHub](https://img.shields.io/badge/GitHub-View%20Source-blue?logo=github)](https://github.com/tue-bmd/zea/blob/main/docs/source/notebooks/models/left_ventricle_segmentation_example.ipynb)\n", "\n", "This notebook demonstrates how to perform left ventricle segmentation in echocardiograms using two different models within the [zea](https://github.com/tue-bmd/zea) framework. For demonstration, we apply both models to the [CAMUS dataset](https://www.creatis.insa-lyon.fr/Challenge/camus/).\n", "\n", "\n", "### 1. EchoNetDynamic\n", "[![Hugging Face model EchoNetDynamic](https://img.shields.io/badge/Hugging%20Face-Model-yellow?logo=huggingface)](https://huggingface.co/zeahub/echonet-dynamic)   [![Paper](https://img.shields.io/badge/Paper-Info-blue)](https://echonet.github.io/dynamic/)\n", "\n", "- Trained on the [EchoNet-Dynamic dataset](https://echonet.github.io/dynamic/).\n", "- Segments the left ventricle in echocardiograms.\n", "\n", "\n", "### 2. Augmented CAMUS Segmentation Model\n", "[![Hugging Face model CAMUS](https://img.shields.io/badge/Hugging%20Face-Model-yellow?logo=huggingface)](https://huggingface.co/gillesvdv/augmented_camus_seg)   [![arXiv](https://img.shields.io/badge/arXiv-Paper-b31b1b.svg)](https://arxiv.org/abs/2502.20100)\n", "\n", "- nnU-Net-based model trained on the augmented CAMUS dataset.\n", "- Segments both the left ventricle and the myocardium (2 labels).\n", "- Reported to be state-of-the-art for left ventricle segmentation on CAMUS." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "‼️ **Important:** This notebook is optimized for **GPU/TPU**. 
Code execution on a **CPU** may be very slow.\n", "\n", "If you are running in Colab, please enable a hardware accelerator via:\n", "\n", "**Runtime → Change runtime type → Hardware accelerator → GPU/TPU** 🚀." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "%pip install zea\n", "%pip install onnxruntime # needed for the Augmented CAMUS Segmentation Model" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Using backend 'tensorflow'\n" ] } ], "source": [ "import os\n", "\n", "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n", "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n", "\n", "from zea import init_device\n", "import matplotlib.pyplot as plt\n", "from keras import ops\n", "from zea.backend.tensorflow.dataloader import make_dataloader\n", "from zea.func import translate\n", "from zea.visualize import plot_image_grid, plot_shape_from_mask, set_mpl_style\n", "\n", "init_device(verbose=False)\n", "set_mpl_style()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load CAMUS Validation Data\n", "\n", "We load a batch of 16 images from the CAMUS validation set and use it as input for both segmentation models. The dataloader clips pixel values to the range `[-45, 0]`, rescales them to `[-1, 1]`, and resizes every image to 256 × 256 pixels." 
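] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A minimal NumPy sketch of the intensity preprocessing that we assume `make_dataloader` performs with `image_range=[-45, 0]`, `clip_image_range=True` and `normalization_range=[-1, 1]`: pixel values are clipped to the image range and then mapped linearly onto the normalization range. The helper `normalize_image` is hypothetical and not part of zea; it only illustrates the arithmetic." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "import numpy as np\n",
 "\n",
 "# Hypothetical helper (assumed behaviour, not zea's implementation):\n",
 "# clip to image_range, then map linearly onto normalization_range.\n",
 "def normalize_image(image, image_range=(-45.0, 0.0), normalization_range=(-1.0, 1.0)):\n",
 "    lo, hi = image_range\n",
 "    new_lo, new_hi = normalization_range\n",
 "    clipped = np.clip(image, lo, hi)  # values outside [lo, hi] saturate at the bounds\n",
 "    return (clipped - lo) / (hi - lo) * (new_hi - new_lo) + new_lo\n",
 "\n",
 "example = np.array([-60.0, -45.0, -22.5, 0.0, 5.0])\n",
 "print(normalize_image(example))  # values: -1, -1, 0, 1, 1"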
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Using pregenerated dataset info file: \u001b[33m/root/.cache/zea/huggingface/datasets/datasets--zeahub--camus-sample/snapshots/617cf91a1267b5ffbcfafe9bebf0813c7cee8493/val/dataset_info.yaml\u001b[0m ...\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: ...for reading file paths in \u001b[33m/root/.cache/zea/huggingface/datasets/datasets--zeahub--camus-sample/snapshots/617cf91a1267b5ffbcfafe9bebf0813c7cee8493/val\u001b[0m\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Dataset was validated on \u001b[32mSeptember 29, 2025\u001b[0m\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Remove \u001b[33m/root/.cache/zea/huggingface/datasets/datasets--zeahub--camus-sample/snapshots/617cf91a1267b5ffbcfafe9bebf0813c7cee8493/val/validated.flag\u001b[0m if you want to redo validation.\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[38;5;214mWARNING\u001b[0m H5Generator: Not all files have the same shape. 
This can lead to issues when resizing images later....\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: H5Generator: Shuffled data.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: H5Generator: Shuffled data.\n" ] } ], "source": [ "n_imgs = 16\n", "INFERENCE_SIZE = 256 # Used for both models\n", "val_dataset = make_dataloader(\n", " \"hf://zeahub/camus-sample/val\",\n", " key=\"data/image_sc\",\n", " batch_size=n_imgs,\n", " shuffle=True,\n", " image_range=[-45, 0],\n", " clip_image_range=True,\n", " normalization_range=[-1, 1],\n", " image_size=(INFERENCE_SIZE, INFERENCE_SIZE),\n", " resize_type=\"resize\",\n", " seed=42,\n", ")\n", "\n", "batch = next(iter(val_dataset))\n", "rgb_batch = ops.concatenate([batch, batch, batch], axis=-1) # Replicate the grayscale channel into 3-channel (RGB) input for EchoNetDynamic" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference with EchoNetDynamic Model\n", "\n", "We first run inference with the EchoNetDynamic model, which expects 3-channel (RGB) input; this is why the grayscale batch was replicated three times along the channel axis in the cell above. The model was trained on the EchoNet-Dynamic dataset, but here we apply it to CAMUS data for demonstration." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from zea.models.echonet import EchoNetDynamic\n", "\n", "# Load model\n", "model_echonet = EchoNetDynamic.from_preset(\"echonet-dynamic\")\n", "\n", "# Inference (expects RGB input)\n", "masks_echonet = model_echonet(rgb_batch)\n", "# Drop the trailing singleton channel axis: (batch, H, W, 1) -> (batch, H, W)\n", "masks_echonet = ops.squeeze(masks_echonet, axis=-1)\n", "masks_echonet = ops.convert_to_numpy(masks_echonet)\n", "\n", "# Visualization: rescale images from [-1, 1] to [0, 1] for display\n", "batch_vis = translate(rgb_batch, [-1, 1], [0, 1])\n", "fig, _ = plot_image_grid(batch_vis, vmin=0, vmax=1)\n", "axes = fig.axes[:n_imgs]\n", "for ax, mask in zip(axes, masks_echonet):\n", " plot_shape_from_mask(ax, mask, color=\"red\", alpha=0.4)\n", "\n", "plt.savefig(\"echonet_output.png\", bbox_inches=\"tight\", dpi=100)\n", "plt.close(fig)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**EchoNetDynamic segmentation results:**\n", "\n", "The red overlay shows the predicted left ventricle mask for each image.\n", "\n", "![EchoNet-Dynamic Example Output](./echonet_output.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference with Augmented CAMUS Model\n", "\n", "Next, we use the Augmented CAMUS nnU-Net model, which segments both the left ventricle and the myocardium (2 labels). This model runs through ONNX Runtime and expects channels-first input in NCHW format (batch, channels, height, width). The dataloader returns channels-last (NHWC) batches, so we transpose the batch before inference." 
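] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The layout change and the label decoding in the inference cell that follows can be illustrated on toy arrays. This NumPy sketch uses random stand-ins (`batch_nhwc` and `scores` are hypothetical, not real images or model outputs) and assumes three score channels ordered background (0), left ventricle (1), myocardium (2). Labels 1 and 2 match the visualization code; treating 0 as background is an assumption." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "import numpy as np\n",
 "\n",
 "rng = np.random.default_rng(0)\n",
 "\n",
 "# Hypothetical stand-in for a dataloader batch: channels-last (NHWC), one grayscale channel.\n",
 "batch_nhwc = rng.uniform(-1.0, 1.0, size=(2, 4, 4, 1)).astype(np.float32)\n",
 "\n",
 "# Move the channel axis to position 1 to obtain channels-first (NCHW) input.\n",
 "batch_nchw = np.transpose(batch_nhwc, (0, 3, 1, 2))\n",
 "print(batch_nchw.shape)  # (2, 1, 4, 4)\n",
 "\n",
 "# Hypothetical per-pixel class scores in NCHW layout: (batch, classes, H, W).\n",
 "scores = rng.normal(size=(2, 3, 4, 4))\n",
 "\n",
 "# The predicted label of each pixel is the index of its highest score along the class axis.\n",
 "labels = np.argmax(scores, axis=1)\n",
 "print(labels.shape)  # (2, 4, 4)\n",
 "\n",
 "# One boolean mask per structure; each pixel gets exactly one label, so the masks never overlap.\n",
 "lv_mask = labels == 1\n",
 "myo_mask = labels == 2\n",
 "print(np.any(lv_mask & myo_mask))  # False"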
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from zea.models.lv_segmentation import AugmentedCamusSeg\n", "import numpy as np\n", "\n", "# Load model and weights\n", "model_camus = AugmentedCamusSeg.from_preset(\"augmented_camus_seg\")\n", "\n", "# Prepare input for ONNX: NHWC -> NCHW (batch, channels, height, width)\n", "batch_np = ops.convert_to_numpy(batch)\n", "onnx_input = np.transpose(batch_np, (0, 3, 1, 2))\n", "\n", "# Inference: per-pixel class scores, shape (batch, classes, H, W)\n", "outputs_camus = model_camus.call(onnx_input)\n", "outputs_camus = np.array(outputs_camus)\n", "# Predicted class = class with the highest score for each pixel\n", "masks_camus = np.argmax(outputs_camus, axis=1) # shape: (batch, H, W)\n", "\n", "# Visualization: show both LV (label 1) and myocardium (label 2)\n", "fig, _ = plot_image_grid(batch_np, vmin=-1, vmax=1)\n", "axes = fig.axes[:n_imgs]\n", "for ax, mask in zip(axes, masks_camus):\n", " plot_shape_from_mask(ax, mask == 1, color=\"red\", alpha=0.3)\n", " plot_shape_from_mask(ax, mask == 2, color=\"blue\", alpha=0.3)\n", "\n", "plt.savefig(\"augmented_camus_seg_output.png\", bbox_inches=\"tight\", dpi=100)\n", "plt.close(fig)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Augmented CAMUS segmentation results:**\n", "\n", "Red: left ventricle mask. Blue: myocardium mask.\n", "\n", "![Augmented CAMUS Segmentation Output](./augmented_camus_seg_output.png)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" } }, "nbformat": 4, "nbformat_minor": 2 }