{ "cells": [ { "cell_type": "markdown", "id": "ea2858d8", "metadata": {}, "source": [ "## Hierarchical VAEs for ultrasound image generation and inpainting\n", "\n", "This notebook shows example use cases of the [Hierarchical Variational Autoencoder (HVAE)](../../_autosummary/zea.models.hvae.rst) as a generative model in the zea framework.\n", "\n", "The model works with the TensorFlow, JAX, and PyTorch backends.\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tue-bmd/zea/blob/main/docs/source/notebooks/models/hvae_model_example.ipynb)\n", " \n", "[![View on GitHub](https://img.shields.io/badge/GitHub-View%20Source-blue?logo=github)](https://github.com/tue-bmd/zea/blob/main/docs/source/notebooks/models/hvae_model_example.ipynb)\n", " \n", "[![Hugging Face model](https://img.shields.io/badge/Hugging%20Face-Model-yellow?logo=huggingface)](https://huggingface.co/zeahub/hvae)\n", "\n", "Note that the HVAE is quite large (**24M**) and computationally expensive." ] }, { "cell_type": "markdown", "id": "fe42ee66", "metadata": {}, "source": [ "‼️ **Important:** This notebook is optimized for **GPU/TPU**. Code execution on a **CPU** may be very slow.\n", "\n", "If you are running in Colab, please enable a hardware accelerator via:\n", "\n", "**Runtime → Change runtime type → Hardware accelerator → GPU/TPU** 🚀." ] }, { "cell_type": "code", "execution_count": 1, "id": "9d91f644", "metadata": {}, "outputs": [], "source": [ "%%capture\n", "%pip install zea" ] }, { "cell_type": "markdown", "id": "4f2e7ceb", "metadata": {}, "source": [ "First, we select a Keras backend, initialize a GPU device, and set the zea style for matplotlib." 
] }, { "cell_type": "code", "execution_count": 2, "id": "a8245c94", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Using backend 'jax'\n" ] } ], "source": [ "import os\n", "\n", "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", "\n", "from keras import ops, random\n", "import matplotlib.pyplot as plt\n", "\n", "from zea.models.hvae import HierarchicalVAE\n", "\n", "from zea import init_device\n", "from zea.display import scan_convert_2d\n", "from zea.agent.selection import UniformRandomLines\n", "from zea.visualize import set_mpl_style, plot_image_grid\n", "from zea.backend.tensorflow.dataloader import make_dataloader\n", "\n", "init_device(verbose=False)\n", "set_mpl_style()" ] }, { "cell_type": "code", "execution_count": 3, "id": "452bf45e", "metadata": {}, "outputs": [], "source": [ "data_size = 256\n", "batch_size = 4\n", "inference_fractions = [0.03, 0.12, 0.18, 0.3, 0.5, 0.6, 1.0]\n", "n_samples = 6" ] }, { "cell_type": "markdown", "id": "845af74d", "metadata": {}, "source": [ "### Data loading\n", "\n", "The HVAE model is trained on short 2D ultrasound acquisitions at a resolution of `256x256x3`, where the last dimension denotes 3 subsequent video frames. We can download a batch of data from the `CAMUS` dataset from the [zeahub HuggingFace](https://huggingface.co/zeahub)." 
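, "\n", "\n", "As a rough sketch of the preprocessing configured in the next cell, assuming `make_dataloader` clips to `image_range` and then rescales linearly to `normalization_range` (an assumption about its internals; the symbols $v$ and $x$ are introduced here only for illustration), a raw pixel value $v$ would map to\n", "\n", "$$x = 2 \\cdot \\frac{\\mathrm{clip}(v, -45, 0) + 45}{45} - 1 \\in [-1, 1],$$\n", "\n", "so that $v = -45$ maps to $x = -1$ and $v = 0$ maps to $x = 1$. The plotting cells in this notebook apply $(x + 1) / 2$ to bring images back to $[0, 1]$ for display."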
] }, { "cell_type": "code", "execution_count": 4, "id": "7eed61ea", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Searching \u001b[33m/root/.cache/zea/huggingface/datasets/datasets--zeahub--camus-sample/snapshots/617cf91a1267b5ffbcfafe9bebf0813c7cee8493/val\u001b[0m for ['.hdf5', '.h5'] files...\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Loading cached result for _find_h5_file_shapes.\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Dataset was validated on \u001b[32mDecember 17, 2025\u001b[0m\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Remove \u001b[33m/root/.cache/zea/huggingface/datasets/datasets--zeahub--camus-sample/snapshots/617cf91a1267b5ffbcfafe9bebf0813c7cee8493/val/validated.flag\u001b[0m if you want to redo validation.\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[38;5;214mWARNING\u001b[0m H5Generator: Not all files have the same shape. This can lead to issues when resizing images later....\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: H5Generator: Shuffled data.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: H5Generator: Shuffled data.\n" ] } ], "source": [ "val_dataset = make_dataloader(\n", " \"hf://zeahub/camus-sample/val\",\n", " key=\"data/image\",\n", " batch_size=batch_size,\n", " image_range=[-45, 0],\n", " n_frames=3,\n", " clip_image_range=True,\n", " normalization_range=[-1, 1],\n", " image_size=(data_size, data_size),\n", " resize_type=\"resize\",\n", " shuffle=True,\n", " seed=1234,\n", ")\n", "batch = next(iter(val_dataset))" ] }, { "cell_type": "markdown", "id": "0b2f3003", "metadata": {}, "source": [ "### Model initialization\n", "The default HVAE can be loaded with the `\"hvae\"` preset.\n", "\n", "If you would like to inspect the architecture of the neural network, you can uncomment the `model.network.print_model()` line." 
] }, { "cell_type": "code", "execution_count": 5, "id": "69890dd7", "metadata": {}, "outputs": [], "source": [ "model = HierarchicalVAE.from_preset(\"hvae\")\n", "# model.network.print_model()" ] }, { "cell_type": "markdown", "id": "dc7435d0", "metadata": {}, "source": [ "We can reconstruct an input image by calling the model, and use `log_density` to estimate how likely the image is under the distribution that the model learned from its training data.\n", "\n", "The top row shows the reconstructed images, and the bottom row the corresponding input images." ] }, { "cell_type": "code", "execution_count": 6, "id": "c3cebcbd", "metadata": {}, "outputs": [], "source": [ "out, *_ = model.call(batch)\n", "elbo = model.log_density(batch)" ] }, { "cell_type": "code", "execution_count": 7, "id": "114e0ae7", "metadata": {}, "outputs": [], "source": [ "images = (ops.concatenate([out, batch], axis=0) + 1) / 2\n", "images = images[..., -1]\n", "images = scan_convert_2d(images, rho_range=(0, data_size), theta_range=(-0.6, 0.6))[0]\n", "fig, _ = plot_image_grid(images)\n", "fig.suptitle(f\"Average log density of samples in bits/dim: {elbo:.2f}\")\n", "plt.savefig(\"hvae_reconstruction_example.png\", bbox_inches=\"tight\", dpi=100)\n", "plt.close()" ] }, { "cell_type": "markdown", "id": "289ceb60", "metadata": {}, "source": [ "![hvae_reconstruction](hvae_reconstruction_example.png)" ] }, { "cell_type": "markdown", "id": "6c606a83", "metadata": {}, "source": [ "### Partial inference\n", "\n", "Additionally, we can look inside the HVAE architecture to see how the image is formed inside the generative model.\n", "\n", "We do this with a custom function [zea.models.hvae.partial_inference](../../_autosummary/zea.models.hvae.rst#zea.models.hvae.HierarchicalVAE.partial_inference) that only passes the input image to a fraction of the layers in the model. The rest of the image is then created with the prior. 
This lets us look at the way information is propagating in the compressed space.\n", "\n", "Let's try this with the first image (top-left) of the previous example." ] }, { "cell_type": "code", "execution_count": 8, "id": "12271249", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Every experiment generates a (6, 256, 256, 3) tensor.\n" ] } ], "source": [ "# What would happen if we only allow the first 3% of the layers to do inference? or 12%? etc.\n", "# For every experiment, we create 6 samples to show the divergence\n", "# of the path after stopping inference.\n", "\n", "out_partial = []\n", "for frac in inference_fractions:\n", " out_partial.append(\n", " model.partial_inference(\n", " measurements=batch[0:1],\n", " num_layers=frac,\n", " n_samples=n_samples,\n", " )\n", " )\n", "print(f\"Every experiment generates a {out_partial[0][0].shape} tensor.\")" ] }, { "cell_type": "markdown", "id": "3e09fa26", "metadata": {}, "source": [ "Next, we create an image from the samples at every inference fraction, as well as a variance map." 
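, "\n", "\n", "In symbols (notation introduced here for illustration, not taken from the zea documentation): let $x^{(k)}_{ij}$ be pixel $(i, j)$ of the last frame of sample $k$, with $k = 0, \\dots, K - 1$ and $K$ equal to `n_samples`. The displayed image draws an independent sample index for every pixel, and the variance map is the per-pixel variance over the samples:\n", "\n", "$$\\tilde{x}_{ij} = x^{(m_{ij})}_{ij}, \\quad m_{ij} \\sim \\mathrm{Uniform}\\{0, \\dots, K - 1\\}, \\qquad \\sigma^2_{ij} = \\frac{1}{K} \\sum_{k=0}^{K-1} \\left( x^{(k)}_{ij} - \\bar{x}_{ij} \\right)^2, \\quad \\bar{x}_{ij} = \\frac{1}{K} \\sum_{k=0}^{K-1} x^{(k)}_{ij}.$$\n", "\n", "The $1/K$ normalization assumes that `ops.var` computes the population variance, as NumPy's `var` does by default."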
] }, { "cell_type": "code", "execution_count": 9, "id": "6bb08b06", "metadata": {}, "outputs": [], "source": [ "# Plotting\n", "plot_images = []\n", "plot_variances = []\n", "for image_set in out_partial:\n", " # We only visualize the last frame of the 3-frame output of the model.\n", " last_frame = image_set[0, ..., -1]\n", "\n", " # For visualization, we create an image by randomly selecting one of\n", " # the samples for every pixel.\n", " # This gives an idea of the diversity of the samples.\n", " mapping = random.randint(ops.shape(last_frame)[1:3], 0, n_samples)\n", " plot_image = ops.take_along_axis(last_frame, mapping[None, ...], axis=0).squeeze(0)\n", " plot_images.append((plot_image + 1) / 2)\n", "\n", " # Additionally, we show a variance map to indicate model uncertainty at pixel level.\n", " plot_variances.append(ops.var(last_frame, axis=0))\n", "\n", "plot_images = ops.stack(plot_images, axis=0)\n", "plot_variances = ops.stack(plot_variances, axis=0)\n", "\n", "# We convert the images from the polar domain to Cartesian\n", "plot_images = scan_convert_2d(plot_images, rho_range=(0, data_size), theta_range=(-0.6, 0.6))[0]\n", "plot_variances = scan_convert_2d(plot_variances, rho_range=(0, data_size), theta_range=(-0.6, 0.6))[\n", " 0\n", "]\n", "\n", "# Clip variance maps for better visualization\n", "quantiles = ops.quantile(plot_variances, 0.999, axis=(1, 2), keepdims=True)\n", "\n", "fig, axs = plt.subplots(2, len(inference_fractions), figsize=(len(inference_fractions) * 2, 2 * 2))\n", "for i in range(len(inference_fractions)):\n", " axs[0, i].imshow(plot_images[i], cmap=\"gray\")\n", " axs[0, i].set_title(f\"Layers: {int(inference_fractions[i] * model.depth)}/{model.depth}\")\n", " axs[0, i].axis(\"off\")\n", " axs[1, i].imshow(plot_variances[i], vmin=0, vmax=quantiles[i], cmap=\"magma\")\n", " axs[1, i].set_title(\"Variance map\")\n", " axs[1, i].axis(\"off\")\n", "fig.suptitle(\"Increasing levels of inference, moving through layers from left to 
right\")\n", "plt.savefig(\"hvae_partial_inference_example.png\", bbox_inches=\"tight\", dpi=100)\n", "plt.close()" ] }, { "cell_type": "markdown", "id": "57188fa5", "metadata": {}, "source": [ "![partial_inference](hvae_partial_inference_example.png)" ] }, { "cell_type": "markdown", "id": "fb869f42", "metadata": {}, "source": [ "Here you can see how uncertainty is progressively resolved within the model for this specific input image. Every layer carves out a small part of the prior to better align all the samples.\n", "\n", "> *Note: The edges and the tip visible in the variance maps stem from artifacts in the EchoNetLVH dataset that the model was trained on.*\n", "\n", "---" ] }, { "cell_type": "markdown", "id": "0f570fa9", "metadata": {}, "source": [ "### Posterior sampling\n", "\n", "The HVAE can also be used with the [zea.agent](https://zea.readthedocs.io/en/latest/_autosummary/zea.agent.html) module, similar to the [Diffusion Model examples](https://zea.readthedocs.io/en/latest/notebooks/agent/agent_example.html).\n", "\n", "In this example, we will subsample lines in the ultrasound images and let the model fill in the gaps." 
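, "\n", "\n", "The subsampling used below can be written compactly, with notation introduced here for illustration: if $x$ is the fully sampled batch and $m \\in \\{0, 1\\}$ is a binary mask of the same shape that equals 1 on the acquired lines, the model input is\n", "\n", "$$y = m \\odot x + (1 - m) \\cdot (-1),$$\n", "\n", "so observed pixels keep their value and unobserved pixels are set to $-1$, the lower end of the normalized range. With 24 of 256 lines acquired per frame, about $24 / 256 \\approx 9.4\\%$ of the pixels are observed."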
] }, { "cell_type": "code", "execution_count": 10, "id": "c61c48cd", "metadata": {}, "outputs": [], "source": [ "# We subsample to 24/256 (9.4%) of the columns (lines).\n", "num_lines = 24\n", "agent = UniformRandomLines(\n", "    n_actions=num_lines,\n", "    n_possible_actions=data_size,\n", "    img_width=data_size,\n", "    img_height=data_size,\n", ")\n", "# The model predicts three frames at once, so we need three masks.\n", "mask = [agent.sample(batch_size=batch_size)[1] for _ in range(3)]\n", "mask = ops.stack(mask, axis=-1)\n", "\n", "# We set everything outside of the mask to -1.0 to mark it as unobserved.\n", "subsampled_batch = ops.where(mask, batch, -1.0)" ] }, { "cell_type": "code", "execution_count": 11, "id": "30113403", "metadata": {}, "outputs": [], "source": [ "# Plotting\n", "plot_images = ops.concatenate([batch, subsampled_batch], axis=0)[..., -1]\n", "plot_images = (plot_images + 1) / 2\n", "plot_images = scan_convert_2d(plot_images, rho_range=(0, data_size), theta_range=(-0.6, 0.6))[0]\n", "plot_image_grid(plot_images)\n", "plt.suptitle(\"CAMUS (top row) and Uniform Random subsampled (bottom row)\")\n", "plt.savefig(\"hvae_subsampled_example.png\", bbox_inches=\"tight\", dpi=100)\n", "plt.close()" ] }, { "cell_type": "markdown", "id": "9e6a08cc", "metadata": {}, "source": [ "![hvae_subsampled](hvae_subsampled_example.png)" ] }, { "cell_type": "markdown", "id": "adcb4313", "metadata": {}, "source": [ "The bottom row of this image is used as the model input; the task is to recover the corresponding image in the top row.\n", "\n", "For this task the encoder of the HVAE needs to be retrained, so we load a specialized version of the weights, `lvh_ur24`:\n", "- `lvh`: The model was trained on EchoNetLVH\n", "- `ur`: The model was trained with UniformRandom subsampling\n", "- `24`: The model had 24/256 lines available during training\n", "\n", "An overview of all versions is available on the [zeahub HuggingFace](https://huggingface.co/zeahub), or you can retrain 
your own with zea!" ] }, { "cell_type": "code", "execution_count": 12, "id": "38691967", "metadata": {}, "outputs": [], "source": [ "model = HierarchicalVAE.from_preset(\"hvae\", version=\"lvh_ur24\")\n", "# Use the subsampled batch as input and generate n_samples posterior samples, just like before.\n", "posterior_samples = model.posterior_sample(subsampled_batch, n_samples=n_samples)" ] }, { "cell_type": "code", "execution_count": 13, "id": "7bece7e6", "metadata": {}, "outputs": [], "source": [ "# Plotting\n", "plot_input = scan_convert_2d(\n", "    subsampled_batch[..., -1], rho_range=(0, data_size), theta_range=(-0.6, 0.6), fill_value=-1\n", ")[0]\n", "\n", "# For reconstruction we visualize the last frame of a single sample from the model\n", "plot_sample = (posterior_samples[:, 0, ..., -1] + 1) / 2\n", "plot_sample = scan_convert_2d(plot_sample, rho_range=(0, data_size), theta_range=(-0.6, 0.6))[0]\n", "\n", "plot_variance = ops.var(posterior_samples, axis=1)[..., -1]\n", "plot_variance = scan_convert_2d(plot_variance, rho_range=(0, data_size), theta_range=(-0.6, 0.6))[0]\n", "\n", "fig, axs = plt.subplots(3, batch_size, figsize=(12, 9))\n", "for ax in axs.flatten():\n", "    ax.axis(\"off\")\n", "for i in range(batch_size):\n", "    axs[0, i].imshow(plot_input[i], cmap=\"gray\")\n", "    axs[0, i].set_title(\"Model Input\")\n", "    axs[1, i].imshow(plot_sample[i], cmap=\"gray\")\n", "    axs[1, i].set_title(\"Reconstruction\")\n", "    im = axs[2, i].imshow(plot_variance[i], cmap=\"magma\")\n", "    axs[2, i].set_title(\"Variance\")\n", "plt.savefig(\"hvae_posterior_example.png\", bbox_inches=\"tight\", dpi=100)\n", "plt.close()" ] }, { "cell_type": "markdown", "id": "975c3528", "metadata": {}, "source": [ "![hvae_posterior](hvae_posterior_example.png)" ] }, { "cell_type": "markdown", "id": "8276ff2b", "metadata": {}, "source": [ "In these 4 examples, the model uses every layer. Unlike in the previous example, the variance maps no longer reflect the model's inner workings. 
The variance maps now show the ambiguity of the ground truth given a partial observation!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" } }, "nbformat": 4, "nbformat_minor": 5 }