{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Task-based transmit beamforming perception-action loop\n", "In this example we will implement a task-based perception-action loop that drives the transmit beamforming pattern towards gaining information about a downstream task variable of interest. We use the left-ventricular inner dimension (LVID), as measured by EchoNetLVH, as our downstream task variable." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tue-bmd/zea/blob/main/docs/source/notebooks/agent/task_based_perception_action_loop.ipynb)\n", " \n", "[![View on GitHub](https://img.shields.io/badge/GitHub-View%20Source-blue?logo=github)](https://github.com/tue-bmd/zea/blob/main/docs/source/notebooks/agent/task_based_perception_action_loop.ipynb)\n", " \n", "[![Hugging Face model](https://img.shields.io/badge/Hugging%20Face-Model-yellow?logo=huggingface)](https://huggingface.co/zeahub/diffusion-echonetlvh)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "‼️ **Important:** This notebook is optimized for **GPU/TPU**. Code execution on a **CPU** may be very slow.\n", "\n", "If you are running in Colab, please enable a hardware accelerator via:\n", "\n", "**Runtime → Change runtime type → Hardware accelerator → GPU/TPU** 🚀." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook steps through a single iteration of the perception-action loop, going from a sparse acquisition $\\rightarrow$ a belief distribution over LVID values $\\rightarrow$ the transmit pattern for the next sparse acquisition. 
The steps for this loop are illustrated in the following diagram:\n", "\n", "![Downstream Task Diagram](./dst_diagram.png)\n", "\n", "(1) Generate a set of posterior samples from the sparse acquisition using Diffusion Posterior Sampling (DPS).\n", "\n", "(2) Pass each posterior sample $x^{(i)}_t$ through the downstream task model $f$ to produce samples from the downstream task distribution.\n", "\n", "(3) Compute the Jacobian of the task model $f$ at each of the posterior samples.\n", "\n", "(4) Average those Jacobian matrices and multiply the result with the pixel-wise variance across the posterior samples to produce the downstream task saliency map.\n", "\n", "(5) Apply K-Greedy Minimization to select $K$ scan lines for the next acquisition." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "%pip install zea" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup / Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Using backend 'jax'\n" ] } ], "source": [ "import os\n", "\n", "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n", "\n", "import matplotlib.pyplot as plt\n", "from keras import ops\n", "from PIL import Image\n", "import numpy as np\n", "import requests\n", "from io import BytesIO\n", "\n", "from zea import init_device\n", "from zea.visualize import plot_image_grid, set_mpl_style\n", "from zea.display import scan_convert_2d, inverse_scan_convert_2d\n", "from zea.func import translate\n", "from zea.io_lib import matplotlib_figure_to_numpy, save_video\n", "\n", "init_device(verbose=False)\n", "set_mpl_style()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "n_prior_steps = 500\n", "n_posterior_steps = 500\n", 
"n_particles = 4" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load the target data" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# NOTE: this is a synthetic PLAX view image generated by a diffusion model.\n", "url = \"https://raw.githubusercontent.com/tue-bmd/zea/main/docs/source/notebooks/assets/plax.png\"\n", "response = requests.get(url)\n", "img = Image.open(BytesIO(response.content)).convert(\"RGBA\")\n", "\n", "# Composite onto a black background (RGB = 0,0,0)\n", "black_bg = Image.new(\"RGBA\", img.size, (0, 0, 0, 255))\n", "img = Image.alpha_composite(black_bg, img)\n", "img = img.convert(\"L\")\n", "img_np = np.asarray(img).astype(np.float32)\n", "img_tensor = ops.convert_to_tensor(img_np)\n", "img_polar = inverse_scan_convert_2d(img_tensor, image_range=(0, 255))\n", "img_polar_np = ops.convert_to_numpy(img_polar)\n", "\n", "# plotting\n", "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))\n", "ax1.imshow(img_np, cmap=\"gray\")\n", "ax1.set_title(\"Cartesian\", fontsize=15)\n", "ax1.axis(\"off\")\n", "ax2.imshow(img_polar_np, cmap=\"gray\")\n", "ax2.set_title(\"Polar\", fontsize=15)\n", "ax2.axis(\"off\")\n", "plt.tight_layout()\n", "plt.savefig(\"cartesian_polar.png\")\n", "plt.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![Cartesian Polar input](./cartesian_polar.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define the downstream task function" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from zea.models.echonetlvh import EchoNetLVH\n", "\n", "# First, load the downstream task model (EchoNetLVH in this case) from zeahub\n", "echonetlvh_model = EchoNetLVH.from_preset(\"echonetlvh\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# We need to precompute the scan conversion 
coordinates so that the\n", "# scan conversion function is differentiable\n", "from zea.display import compute_scan_convert_2d_coordinates\n", "\n", "# set some parameters for scan conversion\n", "n_rho = 224\n", "n_theta = 224\n", "rho_range = (0, n_rho)\n", "theta_range = (np.deg2rad(-45), np.deg2rad(45))\n", "resolution = 1.0\n", "fill_value = 0.0\n", "image_shape = (n_rho, n_theta)\n", "pre_computed_coords, _ = compute_scan_convert_2d_coordinates(\n", " image_shape,\n", " rho_range,\n", " theta_range,\n", " resolution,\n", ")\n", "\n", "\n", "def lvid_downstream_task(posterior_sample):\n", " \"\"\"\n", " Computes the LVID measurement from a posterior sample generated by the diffusion model.\n", "\n", " Params:\n", " posterior_sample (tensor of shape [H, W]) - should be a single posterior\n", " sample, not a batch, to preserve scalar output for differentiability\n", " using backprop.\n", "\n", " Returns:\n", " lvid_length (float)\n", "\n", " NOTE: we leverage that our downstream task variable is a scalar here to use simple autograd\n", " to compute our jacobian values. 
For multivariate downstream task variables, you'll need\n", " to compute the full Jacobian (e.g. with `jax.jacfwd`/`jax.jacrev`), or work with\n", " Jacobian-vector products via `jax.jvp`.\n", " \"\"\"\n", " assert len(ops.shape(posterior_sample)) == 2 # Should just be [H, W]\n", " # First we need to pre-process the posterior sample from the diffusion model\n", " # so that it becomes a valid input to EchoNetLVH.\n", " posterior_sample_normalized = translate(ops.clip(posterior_sample, -1, 1), (-1, 1), (0, 255))\n", " posterior_sample_sc, _ = scan_convert_2d(\n", " posterior_sample_normalized, coordinates=pre_computed_coords, fill_value=fill_value\n", " )\n", " posterior_sample_sc_resized = ops.image.resize(\n", " posterior_sample_sc[None, ..., None], (224, 224)\n", " ) # model expects batch and channel dims\n", " logits = echonetlvh_model(posterior_sample_sc_resized)\n", " key_points = echonetlvh_model.extract_key_points_as_indices(logits)[0]\n", " lvid_bottom_coords, lvid_top_coords = key_points[1], key_points[2]\n", " lvid_length = ops.squeeze(ops.sqrt(ops.sum((lvid_top_coords - lvid_bottom_coords) ** 2)))\n", " return lvid_length\n", "\n", "\n", "def animate_samples(samples, filename, title, fps=3):\n", " samples = translate(ops.clip(samples, -1, 1), (-1, 1), (0, 255))\n", " # bring frame dim to front\n", " samples = ops.moveaxis(samples, -1, 0)\n", "\n", " frames = []\n", " for i in range(len(samples)):\n", " fig, _ = plot_image_grid(\n", " samples[i],\n", " suptitle=title,\n", " vmin=0,\n", " vmax=255,\n", " cmap=\"gray\",\n", " )\n", " frames.append(matplotlib_figure_to_numpy(fig))\n", " plt.close(fig)\n", " save_video(frames, filename, fps=fps)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Simulate a sparse acquisition\n", "We simulate acquiring a sparse set of focused transmits, each beamformed to a single column of pixels, by simply masking the target image to reveal only certain columns of pixels."
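, "\n", "\n", "The masking idea can be sketched with plain NumPy (a minimal, hypothetical stand-alone version; `equispaced_line_mask` is our own helper and may space lines differently than zea's `EquispacedLines`):\n", "\n", "```python\n", "import numpy as np\n", "\n", "def equispaced_line_mask(img_height, img_width, n_lines):\n", "    # Reveal n_lines equally spaced image columns (scan lines), hide the rest.\n", "    mask = np.zeros((img_height, img_width), dtype=bool)\n", "    cols = np.linspace(0, img_width - 1, n_lines).round().astype(int)\n", "    mask[:, cols] = True\n", "    return mask\n", "\n", "mask_sketch = equispaced_line_mask(256, 256, 8)\n", "image_sketch = np.random.rand(256, 256).astype(np.float32)\n", "measurements_sketch = np.where(mask_sketch, image_sketch, 0.0)\n", "```"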
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from zea.agent.selection import EquispacedLines\n", "\n", "fully_sampled_image = ops.image.resize(\n", " ops.convert_to_tensor(img_polar_np[None, ..., None]), (256, 256)\n", ")\n", "fully_sampled_image_normalized = translate(\n", " fully_sampled_image, range_from=(0, 255), range_to=(-1, 1)\n", ")\n", "\n", "img_shape = (256, 256)\n", "line_thickness = 1\n", "factor = 32\n", "equispaced_sampler = EquispacedLines(\n", " n_actions=img_shape[1] // line_thickness // factor,\n", " n_possible_actions=img_shape[1] // line_thickness,\n", " img_width=img_shape[1],\n", " img_height=img_shape[0],\n", ")\n", "\n", "_, mask = equispaced_sampler.sample()\n", "mask = ops.expand_dims(mask, axis=-1)\n", "\n", "measurements = ops.where(mask, fully_sampled_image_normalized, 0.0)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(figsize=(5, 5))\n", "im = ax.imshow(measurements[0, ..., 0], cmap=\"gray\", vmin=-1, vmax=1)\n", "ax.set_title(\"Sparse Measurements\")\n", "ax.axis(\"off\")\n", "plt.tight_layout()\n", "plt.savefig(\"measurements.png\")\n", "plt.close(fig)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![Measurements](./measurements.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Perception step" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First we place the measurements and mask in a 3-frame buffer, since our EchoNetLVH diffusion model is a 3-frame model." 
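, "\n", "\n", "Across iterations of the perception-action loop, this buffer would be rolled so that the newest acquisition occupies the last frame slot. A minimal NumPy sketch of this bookkeeping (our own hypothetical `push_frame` helper, assuming a trailing frame axis; not part of zea):\n", "\n", "```python\n", "import numpy as np\n", "\n", "def push_frame(buffer, new_frame):\n", "    # Drop the oldest frame and append the newest along the last (frame) axis.\n", "    return np.concatenate([buffer[..., 1:], new_frame[..., None]], axis=-1)\n", "\n", "frame_buffer = np.zeros((1, 256, 256, 3), dtype=np.float32)  # [batch, H, W, frames]\n", "new_frame = np.ones((1, 256, 256), dtype=np.float32)\n", "frame_buffer = push_frame(frame_buffer, new_frame)\n", "```"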
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "measurement_buffer = ops.concatenate((ops.zeros((1, *img_shape, 2)), measurements), axis=-1)\n", "mask_buffer = ops.concatenate((ops.zeros((1, *img_shape, 2)), mask), axis=-1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we load the diffusion model (automatically downloaded from the Hugging Face Hub). We can first quickly sample from the prior $\\mathbf{x} \\sim p(\\mathbf{x})$ to see what kinds of images the model has learned to generate." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m500/500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m30s\u001b[0m 44ms/step\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[32mSuccessfully saved GIF to -> \u001b[33mtask_based_prior_samples.gif\u001b[0m\u001b[0m\n" ] } ], "source": [ "from zea.models.diffusion import DiffusionModel\n", "\n", "diffusion_model = DiffusionModel.from_preset(\"diffusion-echonetlvh-3-frame\")\n", "\n", "prior_samples = diffusion_model.sample(\n", " n_samples=n_particles,\n", " n_steps=n_prior_steps,\n", ")\n", "animate_samples(\n", " prior_samples,\n", " \"./task_based_prior_samples.gif\",\n", " title=r\"Prior samples $x\\sim p(x)$\",\n", " fps=9,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![Task Based Prior Samples](./task_based_prior_samples.gif)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That looks correct; we now proceed with posterior sampling to generate samples from the Bayesian posterior $\\{\\mathbf{x}_t^{(i)}\\}_{i=1}^{N_p} \\sim p(X_t \\mid \\mathbf{y}_t)$." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[32mSuccessfully saved GIF to -> \u001b[33mtask_based_posterior_samples.gif\u001b[0m\u001b[0m\n" ] } ], "source": [ "posterior_samples = diffusion_model.posterior_sample(\n", " measurements=measurement_buffer,\n", " mask=mask_buffer,\n", " n_samples=n_particles,\n", " 
n_steps=n_posterior_steps,\n", " initial_step=0,\n", " omega=10,\n", ")\n", "animate_samples(\n", " posterior_samples[0], # posterior_samples has a leading batch dim matching the measurement batch\n", " \"./task_based_posterior_samples.gif\",\n", " title=r\"Posterior samples $x\\sim p(x | y)$\",\n", " fps=9,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![Task Based Posterior Samples](./task_based_posterior_samples.gif)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we use these posterior samples to derive downstream task posterior samples, i.e. beliefs about the value of the LVID. We then compare these to the target LVID measured from the ground-truth image to see how accurate the agent's beliefs are.\n", "\n", "We also plot this, quantifying the downstream uncertainty with a Gaussian fit to the belief samples." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Target LVID: 72.5322036743164\n", "Agent's LVID beliefs: [[69.403465 68.92208 75.11672 74.17236 ]]\n" ] } ], "source": [ "# First let's measure the ground truth LVID from the fully-sampled target image\n", "target_lvid = lvid_downstream_task(fully_sampled_image_normalized[0, ..., 0])\n", "\n", "# Then we can pass each posterior sample through the LVID measurement function\n", "lvid_posterior = ops.vectorized_map(\n", " lambda ps: ops.vectorized_map(lambda p: lvid_downstream_task(p[..., -1]), ps), posterior_samples\n", ")\n", "\n", "print(f\"Target LVID: {target_lvid}\")\n", "print(f\"Agent's LVID beliefs: {lvid_posterior}\")" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from scipy.stats import norm\n", "\n", "samples = ops.convert_to_numpy(lvid_posterior).flatten()\n", "\n", "# --- fit Gaussian ---\n", "mu = np.mean(samples)\n", "sigma = np.std(samples, ddof=1)\n", "\n", "# make it a bit 
taller/thinner if desired\n", "sigma *= 0.8\n", "\n", "# --- x grid for PDF ---\n", "x = np.linspace(mu - 4 * sigma, mu + 4 * sigma, 400)\n", "pdf = norm.pdf(x, mu, sigma)\n", "\n", "fig, ax_density = plt.subplots(figsize=(7, 4))\n", "\n", "# ---- Density axis (left) ----\n", "ax_density.set_ylabel(\"Density\", color=\"white\")\n", "ax_density.set_ylim(0, 0.3)\n", "ax_density.plot(x, pdf, color=\"#FF00FF\", lw=2, label=\"Gaussian PDF\")\n", "ax_density.fill_between(x, pdf, color=\"#FF66FF\", alpha=0.3)\n", "\n", "# Target line\n", "ax_density.axvline(target_lvid, color=\"white\", linestyle=\"--\", lw=1.5, label=\"Target\")\n", "\n", "# ---- Occurrences axis (right) ----\n", "ax_counts = ax_density.twinx()\n", "ax_counts.set_ylabel(\"Occurrences\", color=\"white\")\n", "ax_counts.set_ylim(0, 2.1) # manually cap at 2 occurrences\n", "ax_counts.hist(\n", " samples,\n", " bins=10,\n", " range=(x.min(), x.max()),\n", " color=\"#FF66FF\",\n", " edgecolor=\"white\",\n", " alpha=0.7,\n", " zorder=2,\n", ")\n", "\n", "# Mean/variance text\n", "ax_density.text(\n", " 0.98,\n", " 0.95,\n", " f\"Mean = {mu:.2f}\\nVar = {sigma**2:.2f}\",\n", " ha=\"right\",\n", " va=\"top\",\n", " transform=ax_density.transAxes,\n", " fontsize=12,\n", " color=\"white\",\n", " bbox=dict(boxstyle=\"round,pad=0.3\", fc=\"black\", ec=\"white\", alpha=0.6),\n", ")\n", "\n", "ax_density.set_xlabel(\"LVID measurement\")\n", "ax_density.legend(frameon=False, loc=\"upper left\")\n", "ax_density.grid(alpha=0.2, color=\"white\")\n", "plt.tight_layout()\n", "plt.title(\"LVID target vs beliefs\")\n", "plt.savefig(\"lvid_target_beliefs.png\")\n", "plt.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![LVID target vs beliefs](./lvid_target_beliefs.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Action step\n", "\n", "Finally, we can use our posterior samples and downstream task function to identify the regions of the image space that should be measured in the next 
sparse acquisition, in order to gain information about the LVID. For this we can use the `TaskBasedLines` agent class from `zea.agent.selection`, as follows:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from zea.agent.selection import TaskBasedLines\n", "\n", "agent = TaskBasedLines(\n", " n_actions=img_shape[1] // line_thickness // factor,\n", " n_possible_actions=img_shape[1] // line_thickness,\n", " img_width=img_shape[1],\n", " img_height=img_shape[0],\n", " downstream_task_function=lvid_downstream_task,\n", ")\n", "selected_lines_k_hot, mask, pixelwise_contribution_to_var_dst = agent.sample(\n", " posterior_samples[..., -1]\n", ")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# plotting\n", "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))\n", "\n", "# Plot the downstream task saliency map\n", "ax1.imshow(pixelwise_contribution_to_var_dst[0] ** 0.5) # rescale by sqrt for visualization\n", "ax1.set_title(\"Saliency Map\", fontsize=15)\n", "ax1.axis(\"off\")\n", "\n", "# Plot the mask of selected scan lines\n", "ax2.imshow(mask[0])\n", "ax2.set_title(\"Selected Lines\", fontsize=15)\n", "ax2.axis(\"off\")\n", "\n", "plt.savefig(\"task_based_selection.png\")\n", "plt.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![Task based selection](./task_based_selection.png)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" } }, "nbformat": 4, "nbformat_minor": 2 }