{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Left ventricular hypertrophy segmentation\n", "\n", "In this example we use the [EchoNetLVH](https://echonet.github.io/lvh/) model to identify key points for measuring left ventricular hypertrophy from parasternal long axis echocardiograms. For more information on the method, see the [original paper](https://jamanetwork.com/journals/jamacardiology/fullarticle/2789370):\n", "- Duffy, G., Cheng, P. P., Yuan, N., He, B., Kwan, A. C., Shun-Shin, M. J., ... & Ouyang, D. (2022). High-throughput precision phenotyping of left ventricular hypertrophy with cardiovascular deep learning. JAMA cardiology, 7(4), 386-395." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tue-bmd/zea/blob/main/docs/source/notebooks/models/left_ventricular_hypertrophy_segmentation_example.ipynb)\n", " \n", "[![View on GitHub](https://img.shields.io/badge/GitHub-View%20Source-blue?logo=github)](https://github.com/tue-bmd/zea/blob/main/docs/source/notebooks/models/left_ventricular_hypertrophy_segmentation_example.ipynb)\n", " \n", "[![Hugging Face model](https://img.shields.io/badge/Hugging%20Face-Model-yellow?logo=huggingface)](https://huggingface.co/zeahub/echonetlvh)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "‼️ **Important:** This notebook is optimized for **GPU/TPU**. Code execution on a **CPU** may be very slow.\n", "\n", "If you are running in Colab, please enable a hardware accelerator via:\n", "\n", "**Runtime → Change runtime type → Hardware accelerator → GPU/TPU** 🚀." 
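, "\n", "To confirm an accelerator is actually available, a quick sanity check (assuming the JAX backend configured below) is:\n", "\n", "```python\n", "import jax\n", "\n", "# List the devices JAX can see; expect a GPU or TPU entry on Colab\n", "print(jax.devices())\n", "```"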
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "%pip install zea" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Using backend 'jax'\n" ] } ], "source": [ "import os\n", "\n", "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n", "\n", "import matplotlib.pyplot as plt\n", "from keras import ops\n", "from PIL import Image\n", "import numpy as np\n", "import requests\n", "from io import BytesIO\n", "\n", "from zea import init_device\n", "from zea.visualize import set_mpl_style\n", "\n", "init_device(verbose=False)\n", "set_mpl_style()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# NOTE: this is a synthetic PLAX view image generated by a diffusion model.\n", "url = \"https://raw.githubusercontent.com/tue-bmd/zea/main/docs/source/notebooks/assets/plax.png\"\n", "response = requests.get(url)\n", "img = Image.open(BytesIO(response.content)).convert(\"RGBA\")\n", "\n", "# Composite onto a black background (RGB = 0,0,0)\n", "black_bg = Image.new(\"RGBA\", img.size, (0, 0, 0, 255))\n", "img = Image.alpha_composite(black_bg, img)\n", "\n", "# Convert to grayscale\n", "img = img.convert(\"L\")\n", "\n", "# Convert to numpy\n", "img_np = np.asarray(img).astype(np.float32)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from zea.models.echonetlvh import EchoNetLVH\n", "\n", "# Load model from zeahub\n", "model = EchoNetLVH.from_preset(\"echonetlvh\")\n", "\n", "# Add batch + channel dims\n", "batch = ops.convert_to_tensor(img_np[None, ..., None])\n", "\n", "# Apply the model to the image, producing logits\n", "logits = model(batch)\n", "\n", "# Use the visualization function to overlay heatmaps and 
measurement lines on the input image\n", "images_with_measurements = model.visualize_logits(batch, logits)\n", "\n", "# Plotting\n", "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))\n", "\n", "# Plot input image\n", "ax1.imshow(img_np, cmap=\"gray\")\n", "ax1.set_title(\"Input Image\", fontsize=15)\n", "ax1.axis(\"off\")\n", "\n", "# Plot output with measurements\n", "ax2.imshow(images_with_measurements[0])\n", "ax2.set_title(\"Predicted Measurements\", fontsize=15)\n", "ax2.axis(\"off\")\n", "\n", "plt.tight_layout()\n", "plt.savefig(\"echonetlvh_output.png\")\n", "plt.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![EchoNetLVH Example Output](./echonetlvh_output.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Extracting the measurement points\n", "The EchoNetLVH model outputs four heatmaps -- one for each key point. Each heatmap encodes, per pixel, the probability that the key point lies at that pixel, so we need a function that reduces a heatmap to a single coordinate. There are various ways to do this -- we use a center-of-mass approach, which preserves differentiability.\n", "\n", "The key points printed below are given as [height, width] pixel indices into the input image."
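, "\n", "As an illustration only (a minimal NumPy sketch, not the `zea` implementation), the center of mass of a single heatmap can be computed by normalizing it into a probability map and taking the expected row and column indices:\n", "\n", "```python\n", "import numpy as np\n", "\n", "def center_of_mass(heatmap):\n", "    # Normalize the heatmap into a probability map\n", "    prob = heatmap / heatmap.sum()\n", "    # Expected (probability-weighted) row and column indices\n", "    h = (prob.sum(axis=1) * np.arange(heatmap.shape[0])).sum()\n", "    w = (prob.sum(axis=0) * np.arange(heatmap.shape[1])).sum()\n", "    return h, w\n", "```\n", "\n", "In this notebook, `model.extract_key_points_as_indices(logits)` performs this reduction for all four heatmaps at once."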
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Measurement type: [H1, W1] -> [H2, W2]\n", "LVPW: [299 217] -> [287 226]\n", "LVID: [287 226] -> [182 286]\n", "IVS: [182 286] -> [187 281]\n" ] } ], "source": [ "key_points = ops.cast(model.extract_key_points_as_indices(logits)[0], \"int\")\n", "measurement_keys = [\"LVPW\", \"LVID\", \"IVS\"]\n", "print(\"Measurement type: [H1, W1] -> [H2, W2]\")\n", "for i in range(3):\n", " print(f\"{measurement_keys[i]}: {key_points[i]} -> {key_points[i + 1]}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" } }, "nbformat": 4, "nbformat_minor": 2 }