{ "cells": [ { "cell_type": "markdown", "id": "9272b926", "metadata": {}, "source": [ "# Simulating ultrasound data with ``zea``\n", "\n", "This notebook demonstrates how to simulate ultrasound RF data using the ``zea`` toolbox. We'll define a probe, a scan, and a simple phantom, then use the simulator to generate synthetic RF data. Finally, we'll visualize the results and show how to process the simulated data with a ``zea`` pipeline." ] }, { "cell_type": "markdown", "id": "1096e886", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tue-bmd/zea/blob/main/docs/source/notebooks/data/zea_simulation_example.ipynb)\n", " \n", "[![View on GitHub](https://img.shields.io/badge/GitHub-View%20Source-blue?logo=github)](https://github.com/tue-bmd/zea/blob/main/docs/source/notebooks/data/zea_simulation_example.ipynb)" ] }, { "cell_type": "markdown", "id": "105f92d2", "metadata": {}, "source": [ "‼️ **Important:** This notebook is optimized for **GPU/TPU**. Code execution on a **CPU** may be very slow.\n", "\n", "If you are running in Colab, please enable a hardware accelerator via:\n", "\n", "**Runtime → Change runtime type → Hardware accelerator → GPU/TPU** 🚀." 
] }, { "cell_type": "code", "execution_count": 1, "id": "3587c9c7", "metadata": {}, "outputs": [], "source": [ "%%capture\n", "%pip install zea" ] }, { "cell_type": "code", "execution_count": 2, "id": "7b57f8cf", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", "os.environ[\"ZEA_DISABLE_CACHE\"] = \"1\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "a4588db7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: Using backend 'jax'\n" ] } ], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "import zea\n", "from zea import init_device\n", "from zea.simulator import simulate_rf\n", "from zea.probes import Probe\n", "from zea.scan import Scan\n", "from zea.beamform.delays import compute_t0_delays_planewave\n", "from zea.visualize import set_mpl_style\n", "from zea.beamform import phantoms" ] }, { "cell_type": "code", "execution_count": 4, "id": "bd7f70ad", "metadata": {}, "outputs": [], "source": [ "init_device(verbose=False)\n", "set_mpl_style()" ] }, { "cell_type": "markdown", "id": "ec56c8bd", "metadata": {}, "source": [ "Let's define a helper function to plot RF data." 
] }, { "cell_type": "code", "execution_count": 5, "id": "cc4d7097", "metadata": {}, "outputs": [], "source": [ "def plot_rf(rf_data, title=\"RF Data\", cmap=\"gray\"):\n", " \"\"\"Plot the first transmit and first channel of the RF data.\"\"\"\n", " plt.figure(figsize=(8, 4))\n", " plt.imshow(\n", " rf_data[0, :, :, 0].T,\n", " aspect=\"auto\",\n", " cmap=cmap,\n", " # (left, right, bottom, top): top=0 so that element 0 (the first row) is labeled 0\n", " extent=[0, rf_data.shape[1], rf_data.shape[2], 0],\n", " )\n", " plt.xlabel(\"Sample (axial)\")\n", " plt.ylabel(\"Element (lateral)\")\n", " plt.title(title)\n", " plt.colorbar(label=\"Amplitude\")\n", " plt.tight_layout()\n", " plt.savefig(\"simulation_plot_rf.png\")\n", " plt.close()" ] }, { "cell_type": "markdown", "id": "312ff168", "metadata": {}, "source": [ "## Define `zea.Probe` and `zea.Scan`\n", "\n", "We'll use a linear probe and a simple planewave scan for this simulation. Let's start with the probe definition." ] }, { "cell_type": "code", "execution_count": 6, "id": "9bd4100a", "metadata": {}, "outputs": [], "source": [ "# Define a linear probe\n", "n_el = 64\n", "aperture = 20e-3\n", "probe_geometry = np.stack(\n", " [np.linspace(-aperture / 2, aperture / 2, n_el), np.zeros(n_el), np.zeros(n_el)], axis=1\n", ")\n", "\n", "probe = Probe(\n", " probe_geometry=probe_geometry,\n", " center_frequency=5e6,\n", " sampling_frequency=20e6,\n", ")" ] }, { "cell_type": "markdown", "id": "36de47b3", "metadata": {}, "source": [ "Now we'll define the necessary parameters for the scan object."
] }, { "cell_type": "code", "execution_count": 7, "id": "ee300172", "metadata": {}, "outputs": [], "source": [ "# Define a planewave scan\n", "n_tx = 3\n", "angles = np.linspace(-5, 5, n_tx) * np.pi / 180\n", "sound_speed = 1540.0\n", "\n", "# Set grid and image size\n", "xlims = (-20e-3, 20e-3)\n", "zlims = (10e-3, 35e-3)\n", "width, height = xlims[1] - xlims[0], zlims[1] - zlims[0]\n", "wavelength = sound_speed / probe.center_frequency\n", "grid_size_x = int(width / (0.5 * wavelength)) + 1\n", "grid_size_z = int(height / (0.5 * wavelength)) + 1\n", "\n", "t0_delays = compute_t0_delays_planewave(\n", " probe_geometry=probe_geometry,\n", " polar_angles=angles,\n", " sound_speed=sound_speed,\n", ")\n", "tx_apodizations = np.ones((n_tx, n_el)) * np.hanning(n_el)[None]" ] }, { "cell_type": "markdown", "id": "79bf80fc", "metadata": {}, "source": [ "Now we can initialize the scan object with the defined parameters." ] }, { "cell_type": "code", "execution_count": 8, "id": "2382cfb2", "metadata": {}, "outputs": [], "source": [ "scan = Scan(\n", " n_tx=n_tx,\n", " n_el=n_el,\n", " center_frequency=probe.center_frequency,\n", " sampling_frequency=probe.sampling_frequency,\n", " probe_geometry=probe_geometry,\n", " t0_delays=t0_delays,\n", " tx_apodizations=tx_apodizations,\n", " element_width=np.linalg.norm(probe_geometry[1] - probe_geometry[0]),\n", " focus_distances=np.ones(n_tx) * np.inf,\n", " polar_angles=angles,\n", " initial_times=np.ones(n_tx) * 1e-6,\n", " n_ax=1024,\n", " xlims=xlims,\n", " zlims=zlims,\n", " grid_size_x=grid_size_x,\n", " grid_size_z=grid_size_z,\n", " lens_sound_speed=1000,\n", " lens_thickness=1e-3,\n", " n_ch=1,\n", " selected_transmits=\"all\",\n", " sound_speed=sound_speed,\n", " apply_lens_correction=False,\n", " attenuation_coef=0.0,\n", ")" ] }, { "cell_type": "markdown", "id": "0a455c64", "metadata": {}, "source": [ "## Simulate RF Data\n", "\n", "Let's simulate some RF data using the `simulate_rf` function and initialize a scatterer 
phantom." ] }, { "cell_type": "code", "execution_count": 9, "id": "09a3db57", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Simulated RF data shape: (3, 1024, 64, 1)\n" ] } ], "source": [ "# Create the phantom scatterer positions and magnitudes\n", "positions = phantoms.fish()\n", "magnitudes = np.ones(len(positions), dtype=np.float32)\n", "\n", "simulation_args = {\n", " \"scatterer_positions\": positions,\n", " \"scatterer_magnitudes\": magnitudes,\n", " \"probe_geometry\": probe.probe_geometry,\n", " \"apply_lens_correction\": scan.apply_lens_correction,\n", " \"lens_thickness\": scan.lens_thickness,\n", " \"lens_sound_speed\": scan.lens_sound_speed,\n", " \"sound_speed\": scan.sound_speed,\n", " \"n_ax\": scan.n_ax,\n", " \"center_frequency\": probe.center_frequency,\n", " \"sampling_frequency\": probe.sampling_frequency,\n", " \"t0_delays\": scan.t0_delays,\n", " \"initial_times\": scan.initial_times,\n", " \"element_width\": scan.element_width,\n", " \"attenuation_coef\": scan.attenuation_coef,\n", " \"tx_apodizations\": scan.tx_apodizations,\n", "}\n", "\n", "rf_data = simulate_rf(**simulation_args)\n", "print(\"Simulated RF data shape:\", rf_data.shape)" ] }, { "cell_type": "markdown", "id": "24b7c1cd", "metadata": {}, "source": [ "## Visualize RF Data\n", "\n", "Let's plot the simulated RF data for the first transmit." ] }, { "cell_type": "code", "execution_count": 10, "id": "19435040", "metadata": {}, "outputs": [], "source": [ "plot_rf(rf_data, title=\"Simulated RF Data (Tx 0)\")" ] }, { "cell_type": "markdown", "id": "0880802c", "metadata": {}, "source": [ "![simulation_plot](simulation_plot_rf.png)" ] }, { "cell_type": "markdown", "id": "7b177aef", "metadata": {}, "source": [ "## Process simulated data with `zea.Pipeline`\n", "\n", "We can process the simulated RF data using a `zea` pipeline to obtain a B-mode image."
] }, { "cell_type": "code", "execution_count": 11, "id": "aef498ed", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[38;5;214mWARNING\u001b[0m No azimuth angles provided, using zeros\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[38;5;214mWARNING\u001b[0m No transmit origins provided, using zeros\n", "\u001b[1m\u001b[38;5;36mzea\u001b[0m\u001b[0m: \u001b[33mDEBUG\u001b[0m [zea.Pipeline] The following input keys are not used by the pipeline: {'center_frequency', 'xlims', 'n_el', 'zlims'}. Make sure this is intended. This warning will only be shown once.\n" ] } ], "source": [ "pipeline = zea.Pipeline.from_default(enable_pfield=False, with_batch_dim=False, baseband=False)\n", "parameters = pipeline.prepare_parameters(probe, scan, dynamic_range=(-50, 0))\n", "inputs = {pipeline.key: rf_data}\n", "\n", "outputs = pipeline(**inputs, **parameters)\n", "image = outputs[pipeline.output_key]\n", "\n", "image = zea.display.to_8bit(image, dynamic_range=(-50, 0))\n", "\n", "plt.figure()\n", "plt.imshow(\n", " image,\n", " cmap=\"gray\",\n", " extent=[\n", " scan.xlims[0] * 1e3,\n", " scan.xlims[1] * 1e3,\n", " scan.zlims[1] * 1e3,\n", " scan.zlims[0] * 1e3,\n", " ],\n", ")\n", "plt.xlabel(\"X (mm)\")\n", "plt.ylabel(\"Z (mm)\")\n", "plt.title(\"Simulated B-mode Image\")\n", "plt.tight_layout()\n", "plt.savefig(\"simulation_plot_fish.png\")\n", "plt.close()" ] }, { "cell_type": "markdown", "id": "9ad9e08f", "metadata": {}, "source": [ "![simulation_fish](simulation_plot_fish.png)" ] }, { "cell_type": "markdown", "id": "cfc67088", "metadata": {}, "source": [ "That's it! You have now simulated ultrasound RF data and reconstructed a B-mode image using `zea`." ] }, { "cell_type": "markdown", "id": "c6dc06de", "metadata": {}, "source": [ "## Speedup with Just-In-Time compilation (JIT)\n", "\n", "Without compilation, the `simulate_rf` call above is relatively slow, even for this small example.
Larger experiments with more point scatterers or acquisitions can execute very slowly. In this case, it is advised to [JIT-compile](https://docs.jax.dev/en/latest/jit-compilation.html) the `simulate_rf` function. The way you do this depends on which machine learning backend (e.g., JAX, PyTorch, TensorFlow) you are using (see [documentation](../../installation.rst#Backend) for details). Starting with JAX, you can simply wrap the function with `jax.jit` as follows:\n", "\n", "**JAX**" ] }, { "cell_type": "code", "execution_count": 12, "id": "8aa8873d", "metadata": {}, "outputs": [], "source": [ "from jax import jit\n", "\n", "simulate_rf_jit = jit(simulate_rf, static_argnames=[\"apply_lens_correction\", \"n_ax\"])" ] }, { "cell_type": "markdown", "id": "89b0adf8", "metadata": {}, "source": [ "Let's execute and time the JIT versus non-JIT versions of the `simulate_rf` function to see the speedup." ] }, { "cell_type": "code", "execution_count": 13, "id": "83453631", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1mFunction Timing Statistics\u001b[0m\n", "=====================================================================================================\n", "\u001b[36mFunction\u001b[0m \u001b[32mMean\u001b[0m \u001b[32mMedian\u001b[0m \u001b[32mStd Dev\u001b[0m \u001b[33mMin\u001b[0m \u001b[33mMax\u001b[0m \u001b[35mCount\u001b[0m \n", "-----------------------------------------------------------------------------------------------------\n", "\u001b[36msimulate_rf\u001b[0m \u001b[32m0.220066\u001b[0m \u001b[32m0.218923\u001b[0m \u001b[32m0.021850\u001b[0m \u001b[33m0.189399\u001b[0m \u001b[33m0.255911\u001b[0m \u001b[35m30\u001b[0m \n", "\u001b[36msimulate_rf (JIT)\u001b[0m \u001b[32m0.004081\u001b[0m \u001b[32m0.003444\u001b[0m \u001b[32m0.003207\u001b[0m \u001b[33m0.003159\u001b[0m \u001b[33m0.020947\u001b[0m \u001b[35m30\u001b[0m \n" ] } ], "source": [ "from zea.utils import FunctionTimer\n", "\n", "# Warm-up JIT 
compilation before benchmarking\n", "simulate_rf_jit(**simulation_args)\n", "\n", "timer = FunctionTimer()\n", "timed_rf = timer(simulate_rf, name=\"simulate_rf\")\n", "timed_rf_jit = timer(simulate_rf_jit, name=\"simulate_rf (JIT)\")\n", "\n", "for _ in range(30):\n", " timed_rf_jit(**simulation_args)\n", " timed_rf(**simulation_args)\n", "\n", "timer.print()" ] }, { "cell_type": "markdown", "id": "72b4387a", "metadata": {}, "source": [ "If you are using another backend, a similar approach can be taken:" ] }, { "cell_type": "markdown", "id": "4eba2540", "metadata": {}, "source": [ "**PyTorch**\n", "```python\n", "import torch\n", "simulate_rf_jit = torch.compile(simulate_rf)\n", "```\n", "**TensorFlow**\n", "```python\n", "import tensorflow as tf\n", "simulate_rf_jit = tf.function(simulate_rf)\n", "```" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" } }, "nbformat": 4, "nbformat_minor": 5 }