{ "cells": [ { "cell_type": "markdown", "id": "0bf4a2d4", "metadata": {}, "source": [ "# Using `zea.Models`: a UNet example\n", "\n", "In this notebook, we demonstrate how to use the `zea.Models` interface with a popular deep learning architecture: the UNet. We'll use a pretrained UNet to inpaint missing regions in ultrasound images, and visualize the results. This workflow can be adapted for other tasks and models in the `zea` toolbox." ] }, { "cell_type": "markdown", "id": "4bd24c43", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tue-bmd/zea/blob/main/docs/source/notebooks/models/unet_example.ipynb)\n", " \n", "[![View on GitHub](https://img.shields.io/badge/GitHub-View%20Source-blue?logo=github)](https://github.com/tue-bmd/zea/blob/main/docs/source/notebooks/models/unet_example.ipynb)\n", " \n", "[![Hugging Face model](https://img.shields.io/badge/Hugging%20Face-Model-yellow?logo=huggingface)](https://huggingface.co/zeahub/unet-echonet-inpainter)" ] }, { "cell_type": "markdown", "id": "dc21d55e", "metadata": {}, "source": [ "‼️ **Important:** This notebook is optimized for **GPU/TPU**. Code execution on a **CPU** may be very slow.\n", "\n", "If you are running in Colab, please enable a hardware accelerator via:\n", "\n", "**Runtime → Change runtime type → Hardware accelerator → GPU/TPU** 🚀." ] }, { "cell_type": "code", "execution_count": 1, "id": "ce8a84d3", "metadata": {}, "outputs": [], "source": [ "%%capture\n", "%pip install zea" ] }, { "cell_type": "code", "execution_count": 2, "id": "947a6cf9", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "8a391c3b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Using backend 'jax'\n" ] } ], "source": [ "import matplotlib.pyplot as plt\n", "from keras import ops\n", "\n", "from zea import init_device, log\n", "from zea.backend.tensorflow.dataloader import make_dataloader\n", "from zea.models.unet import UNet\n", "from zea.models.lpips import LPIPS\n", "from zea.agent.masks import random_uniform_lines\n", "from zea.visualize import plot_image_grid, set_mpl_style" ] }, { "cell_type": "markdown", "id": "c0bab632", "metadata": {}, "source": [ "We will work with the GPU if available, and initialize using `init_device` to pick the best available device. Also, (optionally), we will set the matplotlib style for plotting." ] }, { "cell_type": "code", "execution_count": 4, "id": "7d0a6ff9", "metadata": {}, "outputs": [], "source": [ "device = init_device(verbose=False)\n", "set_mpl_style()" ] }, { "cell_type": "markdown", "id": "4b593ea4", "metadata": {}, "source": [ "## Load Data\n", "\n", "We load a small batch from the CAMUS validation dataset hosted on Hugging Face Hub." 
] }, { "cell_type": "code", "execution_count": 5, "id": "50167c4f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Using pregenerated dataset info file: \u001b[33m/root/.cache/zea/huggingface/datasets/datasets--zeahub--camus-sample/snapshots/617cf91a1267b5ffbcfafe9bebf0813c7cee8493/val/dataset_info.yaml\u001b[0m ...\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: ...for reading file paths in \u001b[33m/root/.cache/zea/huggingface/datasets/datasets--zeahub--camus-sample/snapshots/617cf91a1267b5ffbcfafe9bebf0813c7cee8493/val\u001b[0m\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Dataset was validated on \u001b[32mSeptember 24, 2025\u001b[0m\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Remove \u001b[33m/root/.cache/zea/huggingface/datasets/datasets--zeahub--camus-sample/snapshots/617cf91a1267b5ffbcfafe9bebf0813c7cee8493/val/validated.flag\u001b[0m if you want to redo validation.\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[38;5;214mWARNING\u001b[0m H5Generator: Not all files have the same shape. This can lead to issues when resizing images later....\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: H5Generator: Shuffled data.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: H5Generator: Shuffled data.\n" ] } ], "source": [ "n_imgs = 8\n", "\n", "val_dataset = make_dataloader(\n", " \"hf://zeahub/camus-sample/val\",\n", " key=\"data/image\",\n", " batch_size=n_imgs,\n", " shuffle=True,\n", " image_size=[128, 128],\n", " resize_type=\"resize\",\n", " image_range=[-60, 0],\n", " normalization_range=[-1, 1],\n", " seed=42,\n", ")\n", "batch = next(iter(val_dataset))\n", "batch = ops.clip(batch, -1, 1)" ] }, { "cell_type": "markdown", "id": "8236a9f1", "metadata": {}, "source": [ "## Load UNet Model\n", "\n", "We use a pretrained UNet model from `zea` for inpainting." ] }, { "cell_type": "code", "execution_count": 6, "id": "538e4e31", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Available built-in zea presets for UNet: ['unet-echonet-inpainter']\n" ] } ], "source": [ "presets = list(UNet.presets.keys())\n", "log.info(f\"Available built-in zea presets for UNet: {presets}\")\n", "\n", "model = UNet.from_preset(\"unet-echonet-inpainter\")" ] }, { "cell_type": "markdown", "id": "9a0481c9", "metadata": {}, "source": [ "## Simulate Missing Data\n", "\n", "We simulate missing data by masking out random columns in each image (e.g., 75% missing). This is a common scenario in cognitive ultrasound where some scanlines may be missing (i.e. not acquired) or corrupted." ] }, { "cell_type": "code", "execution_count": 7, "id": "1fe5aefe", "metadata": {}, "outputs": [], "source": [ "n_columns = 128 # batch.shape[2]\n", "mask = random_uniform_lines(n_columns // 4, n_columns, n_imgs)\n", "corrupted = batch * ops.cast(mask[:, None, :, None], batch.dtype)" ] }, { "cell_type": "markdown", "id": "5425a18e", "metadata": {}, "source": [ "## Inpaint with UNet\n", "\n", "We use the UNet to inpaint the missing regions." 
] }, { "cell_type": "code", "execution_count": 8, "id": "db6bd7eb", "metadata": {}, "outputs": [], "source": [ "inpainted = model(corrupted)\n", "inpainted = ops.clip(inpainted, -1, 1)" ] }, { "cell_type": "markdown", "id": "e85c2595", "metadata": {}, "source": [ "## Evaluate Perceptual Similarity\n", "\n", "We use the LPIPS metric to evaluate perceptual similarity between the ground truth and inpainted images. For more detailed example of this metric, see [this notebook](../metrics/lpips_example.ipynb)." ] }, { "cell_type": "code", "execution_count": 9, "id": "cefc069e", "metadata": {}, "outputs": [], "source": [ "lpips = LPIPS.from_preset(\"lpips\")\n", "lpips_scores = lpips([inpainted, inpainted])\n", "lpips_scores = ops.convert_to_numpy(lpips_scores)" ] }, { "cell_type": "markdown", "id": "d29f0011", "metadata": {}, "source": [ "## Visualization\n", "\n", "We plot the ground truth, corrupted, inpainted, and error images. The LPIPS score is shown on each inpainted image. Note that this model was trained on the EchoNet-Dynamic dataset, whereas we are testing now on the CAMUS dataset." ] }, { "cell_type": "code", "execution_count": 10, "id": "91d488b9", "metadata": {}, "outputs": [], "source": [ "error = ops.abs(batch - inpainted)\n", "imgs = ops.concatenate([batch, corrupted, inpainted, error], axis=0)\n", "imgs = ops.convert_to_numpy(imgs)\n", "\n", "cmaps = [\"gray\"] * (3 * n_imgs) + [\"viridis\"] * n_imgs\n", "\n", "fig, _ = plot_image_grid(\n", " imgs,\n", " vmin=-1,\n", " vmax=1,\n", " ncols=n_imgs,\n", " remove_axis=False,\n", " cmap=cmaps,\n", " figsize=(n_imgs * 2, 6),\n", ")\n", "\n", "titles = [\"Ground Truth\", \"Corrupted\", \"Inpainted\", \"Error\"]\n", "for i, ax in enumerate(fig.axes[: len(titles) * n_imgs]):\n", " if i % n_imgs == 0:\n", " ax.set_ylabel(titles[i // n_imgs])\n", "\n", "# Show LPIPS score on each inpainted image\n", "for ax, lpips_score in zip(fig.axes[n_imgs * 2 : 3 * n_imgs], lpips_scores):\n", " ax.text(\n", " 0.95,\n", " 0.95,\n", " f\"LPIPS: {float(lpips_score):.4f}\",\n", " ha=\"right\",\n", " va=\"top\",\n", " transform=ax.transAxes,\n", " fontsize=8,\n", " color=\"yellow\",\n", " )\n", "fig.savefig(\"./inpainting_results.png\", dpi=200, bbox_inches=\"tight\")\n", "plt.close(fig)" ] }, { "cell_type": "markdown", "id": "deb77f4b", "metadata": {}, "source": [ "![UNet Inpainting Results](./inpainting_results.png)" ] }, { "cell_type": "markdown", "id": "07263d70", "metadata": {}, "source": [ "You can try other UNet presets or experiment with different masking strategies to explore the capabilities of `zea.Models`!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" } }, "nbformat": 4, "nbformat_minor": 5 }