# Source code for zea.beamform.beamformer

"""Main beamforming functions for ultrasound imaging.

This module implements the core time-of-flight (TOF) correction pipeline used
by the :class:`~zea.ops.ultrasound.TOFCorrection` operation.  It also exposes
the lower-level building blocks (delay computation, f-number masking, phase
rotation, etc.) so that they can be used independently.
"""

import keras
import numpy as np
from keras import ops

from zea.beamform.lens_correction import compute_lens_corrected_travel_times
from zea.func.tensor import vmap
from zea.internal.checks import _check_raw_data


def fnum_window_fn_rect(normalized_angle):
    """Rectangular (boxcar) window for f-number masking.

    Every angle inside the f-number acceptance cone receives full weight and
    every angle outside it receives none.

    Args:
        normalized_angle (Tensor): Angle normalized so that 0 is on-axis and
            1 is the edge of the f-number cone.

    Returns:
        Tensor: Binary weights (1.0 where ``normalized_angle <= 1``, else 0.0)
        with the same shape as *normalized_angle*.
    """
    inside_cone = normalized_angle <= 1.0
    return ops.where(inside_cone, 1.0, 0.0)
def fnum_window_fn_hann(normalized_angle):
    """Hann (raised-cosine) window for f-number masking.

    The weight falls smoothly from 1 on-axis (``normalized_angle = 0``) to 0
    at the cone edge (``normalized_angle = 1``) and is 0 beyond the edge.

    Args:
        normalized_angle (Tensor): Normalized angle values.

    Returns:
        Tensor: Apodization weights with the same shape as *normalized_angle*.
    """
    raised_cosine = 0.5 * (1 + ops.cos(np.pi * normalized_angle))
    return ops.where(normalized_angle <= 1.0, raised_cosine, 0.0)
def fnum_window_fn_tukey(normalized_angle, alpha=0.5):
    """Tukey (tapered-cosine) window for f-number masking.

    The window is flat (weight 1) for ``|normalized_angle| < 1 - alpha``.
    Between that point and 1 it tapers with a half-cosine lobe, and it is 0
    at and beyond 1. ``alpha = 0`` approximates a rectangular window and
    ``alpha = 1`` gives a Hann window.

    Args:
        normalized_angle (Tensor): Normalized angle values. The absolute value
            is taken and clipped to ``[0, 1]`` before evaluation, so inputs
            above 1 map to weight 0.
        alpha (float, optional): Fraction of the window occupied by the
            cosine taper. Defaults to ``0.5``.

    Returns:
        Tensor: Apodization weights with the same shape as *normalized_angle*.
    """
    x = ops.clip(ops.abs(normalized_angle), 0.0, 1.0)
    flat_end = 1.0 - alpha
    # The 1e-6 keeps alpha == 0 from dividing by zero.
    taper = 0.5 * (1 + ops.cos(np.pi * (x - flat_end) / (ops.abs(alpha) + 1e-6)))
    taper_or_zero = ops.where(x < 1.0, taper, 0.0)
    return ops.where(x < flat_end, 1.0, taper_or_zero)
def tof_correction(
    data,
    flatgrid,
    t0_delays,
    tx_apodizations,
    sound_speed,
    probe_geometry,
    initial_times,
    sampling_frequency,
    demodulation_frequency,
    f_number,
    polar_angles,
    focus_distances,
    t_peak,
    tx_waveform_indices,
    transmit_origins,
    apply_lens_correction=False,
    lens_thickness=1e-3,
    lens_sound_speed=1000,
    fnum_window_fn=fnum_window_fn_rect,
    sos_map=None,
    sos_grid_x=None,
    sos_grid_z=None,
):
    """Time-of-flight (TOF) correction for ultrasound data on a flat pixel grid.

    Corrects raw RF or IQ data for differences in propagation time from the
    transmitter through each pixel and back to every receiving element.

    Two modes are supported:

    * **Homogeneous medium** (default) — a constant ``sound_speed`` is used to
      compute delays analytically via :func:`calculate_delays`.
    * **Heterogeneous medium** — a spatially-varying speed-of-sound map
      (``sos_map``) is provided and delays are computed numerically via
      :func:`calculate_delays_heterogeneous_medium`.

    .. important::
        The heterogeneous mode currently requires **multistatic** acquisitions
        (``n_tx == n_el``) and does not support lens correction.

    After delay computation the data is interpolated to the requested pixel
    positions, masked with the receive f-number aperture, and — for IQ data —
    phase-rotated to compensate for the demodulation carrier (see
    :func:`complex_rotate`).

    Args:
        data (Tensor): Input RF or IQ data of shape ``(n_tx, n_ax, n_el, n_ch)``.
            Use ``n_ch=1`` for RF data and ``n_ch=2`` for IQ (in-phase /
            quadrature).
        flatgrid (Tensor): Pixel locations ``(x, y, z)`` of shape ``(n_pix, 3)``.
        t0_delays (Tensor): Per-element transmit fire times, shifted so that
            the first element fires at *t = 0*, of shape ``(n_tx, n_el)``.
        tx_apodizations (Tensor): Transmit apodization weights of shape
            ``(n_tx, n_el)``.
        sound_speed (float): Speed of sound in m/s.
        probe_geometry (Tensor): Element positions ``(x, y, z)`` of shape
            ``(n_el, 3)``.
        initial_times (Tensor): Per-transmit time offsets of shape ``(n_tx,)``.
        sampling_frequency (float): Sampling frequency in Hz.
        demodulation_frequency (float): Demodulation (carrier) frequency in Hz.
            Only used when ``n_ch=2`` (IQ data).
        f_number (float): Receive f-number. Set to ``0`` to disable f-number
            masking.
        polar_angles (Tensor): Steering angles in radians of shape ``(n_tx,)``.
        focus_distances (Tensor): Focus distances in meters of shape
            ``(n_tx,)``. Use ``0`` or ``np.inf`` for plane-wave transmission.
        t_peak (Tensor): Time of each waveform peak in seconds of shape
            ``(n_waveforms,)``.
        tx_waveform_indices (Tensor): Index into ``t_peak`` for each transmit
            of shape ``(n_tx,)``.
        transmit_origins (Tensor): Origin of each transmit beam of shape
            ``(n_tx, 3)``.
        apply_lens_correction (bool, optional): Apply acoustic-lens correction
            to the receive travel times (slower but more accurate in the
            near-field). Defaults to ``False``.
        lens_thickness (float, optional): Lens thickness in meters. Defaults
            to ``1e-3``.
        lens_sound_speed (float, optional): Speed of sound inside the lens in
            m/s. Defaults to ``1000``.
        fnum_window_fn (callable, optional): Window function applied to the
            normalized angle for f-number masking. Receives values in
            ``[0, 1]`` and should return ``0`` for values ``> 1``. Defaults to
            :func:`fnum_window_fn_rect`.
        sos_map (Tensor, optional): 2-D speed-of-sound map of shape
            ``(Nz, Nx)`` in m/s. When provided, delays are computed
            numerically (heterogeneous mode, multistatic only). Defaults to
            ``None``.
        sos_grid_x (Tensor, optional): x-coordinates of ``sos_map`` columns.
        sos_grid_z (Tensor, optional): z-coordinates of ``sos_map`` rows.

    Returns:
        Tensor: Time-of-flight corrected data of shape
        ``(n_tx, n_pix, n_el, n_ch)``.

    Raises:
        AssertionError: If ``data`` is not 4-D, if the input shapes are
            inconsistent (checked by :func:`_validate_delay_inputs`), or if
            lens correction is requested together with ``sos_map``.
    """
    assert len(data.shape) == 4, (
        "The input data should have 4 dimensions, "
        f"namely n_tx, n_ax, n_el, n_ch, got {len(data.shape)} dimensions: {data.shape}"
    )
    _, n_ax, n_el, _ = ops.shape(data)
    n_pix = ops.shape(flatgrid)[0]
    _validate_delay_inputs(data, flatgrid, t0_delays, probe_geometry, tx_apodizations)

    # ---- Compute delays ------------------------------------------------
    # txdel: transmit delay from t=0 to wavefront reaching each pixel
    # rxdel: receive delay from each pixel back to each element
    # After this block both have a consistent layout:
    #   txdel: (n_pix, n_tx)   rxdel: (n_pix, n_el)
    if sos_map is None:
        # calculate_delays returns txdel (n_pix, n_tx), rxdel (n_pix, n_el)
        txdel, rxdel = calculate_delays(
            flatgrid,
            t0_delays,
            tx_apodizations,
            probe_geometry,
            initial_times,
            sampling_frequency,
            sound_speed,
            focus_distances,
            polar_angles,
            t_peak,
            tx_waveform_indices,
            transmit_origins,
            apply_lens_correction,
            lens_thickness,
            lens_sound_speed,
        )
    else:
        # Use truthiness so falsy values such as 0 or np.False_ are accepted
        # (the previous `is False` identity check rejected them).
        assert not apply_lens_correction, (
            "Lens correction is not currently supported in heterogeneous SOS mode. "
            "Either set apply_lens_correction=False or set sos_map=None."
        )
        txdel, rxdel = calculate_delays_heterogeneous_medium(
            flatgrid,
            sos_map,
            sos_grid_x,
            sos_grid_z,
            t0_delays,
            probe_geometry,
            initial_times,
            sampling_frequency,
            t_peak,
            tx_waveform_indices,
        )
        # calculate_delays_heterogeneous_medium returns txdel (n_tx, n_pix),
        # rxdel (n_el, n_pix); transpose both to the shared convention.
        txdel = ops.moveaxis(txdel, 1, 0)  # -> (n_pix, n_tx)
        rxdel = ops.moveaxis(rxdel, 1, 0)  # -> (n_pix, n_el)

    # ---- F-number mask (receive aperture) ------------------------------
    # Shape (n_pix, n_el, 1); all ones when f_number == 0.
    mask = ops.cond(
        f_number == 0,
        lambda: ops.ones((n_pix, n_el, 1)),
        lambda: fnumber_mask(flatgrid, probe_geometry, f_number, fnum_window_fn=fnum_window_fn),
    )
    if sos_map is not None:
        # Prevent gradients from flowing through the mask when optimising
        # through the heterogeneous beamformer (e.g. SOS estimation).
        mask = ops.stop_gradient(mask)

    # ---- Correct a single transmit (closure) ---------------------------
    def _correct_single_tx(data_tx, txdel_tx, mask_tx=None):
        """Apply delay-and-interpolate for one transmit event.

        Args:
            data_tx (Tensor): RF/IQ data for one transmit ``(n_ax, n_el, n_ch)``.
            txdel_tx (Tensor): Transmit delays ``(n_pix, 1)``.
            mask_tx (Tensor, optional): Per-pixel transmit mask ``(n_pix, 1)``.

        Returns:
            Tensor: TOF-corrected data ``(n_pix, n_el, n_ch)``.
        """
        # Total delay per pixel per element: (n_pix, n_el)
        delays = rxdel + txdel_tx

        # Interpolate data at the computed delay positions
        tof_tx = apply_delays(data_tx, delays, clip_min=0, clip_max=n_ax - 1)

        # Apply f-number mask(s)
        if mask_tx is not None:
            tof_tx = tof_tx * mask * mask_tx[:, :, None]
        else:
            tof_tx = tof_tx * mask

        # Phase rotation for IQ data (see complex_rotate docstring)
        if data_tx.shape[-1] == 2:
            total_delay_seconds = delays / sampling_frequency
            theta = 2 * np.pi * demodulation_frequency * total_delay_seconds
            tof_tx = complex_rotate(tof_tx, theta)

        return tof_tx

    # ---- Vectorize over transmits --------------------------------------
    # Reshape txdel from (n_pix, n_tx) -> (n_tx, n_pix, 1) for per-tx slicing
    txdel = ops.moveaxis(txdel, 1, 0)[..., None]

    if sos_map is None:
        return vmap(_correct_single_tx)(data, txdel)

    # Heterogeneous path: apply the transmit f-number mask and use gradient
    # checkpointing to limit memory consumption. Because n_tx == n_el here,
    # row i of mask_tx (n_el, n_pix, 1) is the mask of the element that fires
    # transmit i.
    mask_tx = ops.moveaxis(mask, 1, 0)
    _correct_single_tx_ckpt = keras.remat(_correct_single_tx)
    return vmap(_correct_single_tx_ckpt)(data, txdel, mask_tx)
def _validate_delay_inputs(data, grid, t0_delays, probe_geometry, tx_apodizations):
    """Check the shapes shared by every delay-computation routine.

    Args:
        data (Tensor): Raw RF or IQ data of shape ``(n_tx, n_ax, n_el, n_ch)``.
        grid (Tensor): Pixel coordinates of shape ``(n_pix, 3)``.
        t0_delays (Tensor): Per-element transmit delays of shape
            ``(n_tx, n_el)``.
        probe_geometry (Tensor): Element positions of shape ``(n_el, 3)``.
        tx_apodizations (Tensor): Transmit apodization weights of shape
            ``(n_tx, n_el)``.

    Raises:
        AssertionError: If any of ``grid``, ``t0_delays``, ``probe_geometry``
            or ``tx_apodizations`` is not 2-D, or if their shapes disagree
            with the ``n_tx`` / ``n_el`` sizes taken from ``data``.
    """
    n_tx, _, n_el, _ = ops.shape(data)
    _check_raw_data(data)

    for arr in (grid, t0_delays, probe_geometry, tx_apodizations):
        assert arr.ndim == 2, f"Expected a 2-D array, got shape {arr.shape}."

    assert ops.shape(grid)[1] == 3, f"Expected grid to have shape (n_pix, 3), got {grid.shape}."

    expected_probe_shape = (n_el, 3)
    assert ops.shape(probe_geometry) == expected_probe_shape, (
        f"Expected probe_geometry to have shape (n_el, 3), "
        f"got {probe_geometry.shape} != {expected_probe_shape}."
    )

    # t0_delays and tx_apodizations share the same (n_tx, n_el) layout.
    expected_tx_shape = (n_tx, n_el)
    assert ops.shape(t0_delays) == expected_tx_shape, (
        f"Expected t0_delays to have shape (n_tx, n_el), got {t0_delays.shape} != {expected_tx_shape}."
    )
    assert ops.shape(tx_apodizations) == expected_tx_shape, (
        "Expected tx_apodizations to have shape (n_tx, n_el), "
        f"got {tx_apodizations.shape} != {expected_tx_shape}."
    )
def calculate_delays(
    grid,
    t0_delays,
    tx_apodizations,
    probe_geometry,
    initial_times,
    sampling_frequency,
    sound_speed,
    focus_distances,
    polar_angles,
    t_peak,
    tx_waveform_indices,
    transmit_origins,
    apply_lens_correction=False,
    lens_thickness=None,
    lens_sound_speed=None,
    n_iter=2,
):
    """Compute transmit and receive delays (in samples) to every pixel.

    The round-trip delay of a pixel is split into:

    * a **transmit delay**, from the start of transmission until the wavefront
      reaches the pixel (see :func:`transmit_delays`), plus the peak time of
      the transmit waveform; and
    * a **receive delay**, from the pixel back to each transducer element.

    Both outputs are scaled by ``sampling_frequency`` so they are expressed in
    samples rather than seconds.

    Args:
        grid (Tensor): Pixel coordinates of shape ``(n_pix, 3)``.
        t0_delays (Tensor): Per-element transmit delays in seconds of shape
            ``(n_tx, n_el)``, shifted so that the smallest delay is 0.
        tx_apodizations (Tensor): Transmit apodization weights of shape
            ``(n_tx, n_el)``.
        probe_geometry (Tensor): Element positions of shape ``(n_el, 3)``.
        initial_times (Tensor): Per-transmit time offsets of shape ``(n_tx,)``.
        sampling_frequency (float): Sampling frequency in Hz.
        sound_speed (float): Assumed speed of sound in m/s.
        focus_distances (Tensor): Focus distances of shape ``(n_tx,)``. Use
            ``0`` or ``np.inf`` for plane-wave transmission.
        polar_angles (Tensor): Polar steering angles in radians of shape
            ``(n_tx,)``.
        t_peak (Tensor): Waveform peak times in seconds of shape
            ``(n_waveforms,)``.
        tx_waveform_indices (Tensor): Index into ``t_peak`` for each transmit
            of shape ``(n_tx,)``.
        transmit_origins (Tensor): Origin of each transmit beam of shape
            ``(n_tx, 3)``.
        apply_lens_correction (bool, optional): Use lens-corrected receive
            travel times (slower but more accurate in the near-field).
            Defaults to ``False``.
        lens_thickness (float, optional): Lens thickness in meters. Required
            when ``apply_lens_correction`` is truthy.
        lens_sound_speed (float, optional): Speed of sound in the lens in m/s.
            Required when ``apply_lens_correction`` is truthy.
        n_iter (int, optional): Newton-Raphson iterations for lens correction.
            Defaults to ``2``.

    Returns:
        tuple[Tensor, Tensor]:
            - **transmit_delays** — of shape ``(n_pix, n_tx)``.
            - **receive_delays** — of shape ``(n_pix, n_el)``.

    Raises:
        AssertionError: If lens correction is requested without
            ``lens_thickness`` or ``lens_sound_speed``.
    """
    # Receive travel times in seconds, element <-> pixel.
    if apply_lens_correction:
        assert lens_thickness is not None, "lens_thickness must be provided for lens correction."
        assert lens_sound_speed is not None, (
            "lens_sound_speed must be provided for lens correction."
        )
        rx_delays = compute_lens_corrected_travel_times(
            probe_geometry,
            grid,
            lens_thickness,
            lens_sound_speed,
            sound_speed,
            n_iter=n_iter,
        )
    else:
        # Straight-line distance (n_pix, n_el) divided by a constant speed.
        rx_delays = distance_Rx(grid, probe_geometry) / sound_speed

    # Map transmit_delays over transmits. grid, rx_delays and azimuth_angle
    # (passed as None) are shared; the output stacks transmits along axis 1.
    per_transmit = vmap(
        transmit_delays,
        in_axes=(None, 0, 0, None, 0, 0, 0, None, 0),
        out_axes=1,
    )
    tx_delays = per_transmit(
        grid,
        t0_delays,
        tx_apodizations,
        rx_delays,
        focus_distances,
        polar_angles,
        initial_times,
        None,
        transmit_origins,
    )

    # Shift each transmit by the peak time of its waveform: (1, n_tx).
    peak_offsets = ops.take(t_peak, tx_waveform_indices)[None]

    # TODO: decide whether non-finite delays should be replaced
    # (e.g. with ops.nan_to_num) before converting to samples.

    # Seconds -> samples.
    tx_delays = (tx_delays + peak_offsets) * sampling_frequency
    rx_delays = rx_delays * sampling_frequency
    return tx_delays, rx_delays
def apply_delays(data, delays, clip_min: int = -1, clip_max: int = -1):
    """Interpolate RF/IQ data at fractional sample positions.

    Because the exact delay to a pixel will almost never fall on an integer
    sample index, this function performs **linear interpolation** between the
    two nearest samples (``floor(delay)`` and ``floor(delay) + 1``).

    When a bound is active, both neighbouring indices of any delay outside
    the bound collapse onto the same clipped index. The two interpolation
    weights then cancel, so such delays yield ``0``. This holds for delays
    ``< clip_min`` and for delays ``>= clip_max``.

    Args:
        data (Tensor): RF or IQ data of shape ``(n_ax, n_el, n_ch)``.
        delays (Tensor): Delays **in samples** of shape ``(n_pix, n_el)``.
        clip_min (int, optional): Minimum allowed sample index. ``-1``
            disables the lower bound. Defaults to ``-1``.
        clip_max (int, optional): Maximum allowed sample index. ``-1``
            disables the upper bound. Defaults to ``-1``.

    Returns:
        Tensor: Interpolated samples of shape ``(n_pix, n_el, n_ch)``, cast to
        the dtype of ``delays``.
    """
    # Add a channel axis so the delays line up with the data:
    # (n_pix, n_el) -> (n_pix, n_el, 1)
    delays = delays[..., None]

    # Integer sample indices just below (d0) and just above (d1) each delay,
    # each of shape (n_pix, n_el, 1).
    d0 = ops.cast(ops.floor(delays), "int32")
    d1 = d0 + 1

    # Clip the indices to ensure correct behavior on CPU (out-of-range
    # gathers). Each bound is applied independently so that -1 disables only
    # that bound, as documented. Previously clipping was skipped entirely
    # unless both bounds were given.
    if clip_min != -1:
        lower = ops.cast(clip_min, d0.dtype)
        d0 = ops.maximum(d0, lower)
        d1 = ops.maximum(d1, lower)
    if clip_max != -1:
        upper = ops.cast(clip_max, d0.dtype)
        d0 = ops.minimum(d0, upper)
        d1 = ops.minimum(d1, upper)

    # For IQ data, duplicate the indices so both channels are gathered.
    if data.shape[-1] == 2:
        d0 = ops.concatenate([d0, d0], axis=-1)
        d1 = ops.concatenate([d1, d1], axis=-1)

    # Gather, for every pixel and element, the samples at the two neighbouring
    # indices along the axial (sample) axis: (n_pix, n_el, n_ch).
    data0 = ops.take_along_axis(data, d0, 0)
    data1 = ops.take_along_axis(data, d1, 0)

    # Linear interpolation; everything is cast to the float dtype of delays.
    d0 = ops.cast(d0, delays.dtype)
    d1 = ops.cast(d1, delays.dtype)
    data0 = ops.cast(data0, delays.dtype)
    data1 = ops.cast(data1, delays.dtype)
    reflection_samples = (d1 - delays) * data0 + (delays - d0) * data1
    return reflection_samples
def complex_rotate(iq, theta):
    r"""Rotate the phase of IQ data by the angle *theta*.

    Delaying IQ-demodulated data requires more than interpolating the I and Q
    channels separately. The phase shift of the carrier must be compensated as
    well, which amounts to the rotation

    .. math::
        I_\Delta &= I' \cos\theta - Q' \sin\theta \\
        Q_\Delta &= Q' \cos\theta + I' \sin\theta

    Args:
        iq (Tensor): IQ data of shape ``(..., 2)``; index 0 of the last axis is
            I, index 1 is Q.
        theta (Tensor or float): Rotation angle in radians, broadcastable to
            ``iq[..., 0]``.

    Returns:
        Tensor: Rotated IQ data of shape ``(..., 2)``.

    Raises:
        AssertionError: If the last axis of ``iq`` does not have size 2.

    .. dropdown:: Derivation

        The RF signal is recovered from its IQ components as

        .. math::
            x(t) &= I(t)\cos(\omega_c t) + Q(t)\cos(\omega_c t + \pi/2)\\
            &= I(t)\cos(\omega_c t) - Q(t)\sin(\omega_c t)

        To delay `x(t)` by `Δt`, substitute :math:`t=t+\Delta t` and define
        :math:`I'(t) = I(t + \Delta t)`, :math:`Q'(t) = Q(t + \Delta t)` and
        :math:`\theta=\omega_c\Delta t`:

        .. math::
            x(t + \Delta t) &= I'(t) \cos(\omega_c (t + \Delta t))
            - Q'(t) \sin(\omega_c (t + \Delta t))\\
            &= \overbrace{(I'(t)\cos(\theta) - Q'(t)\sin(\theta) )}^{I_\Delta(t)}
            \cos(\omega_c t)\\
            &- \overbrace{(Q'(t)\cos(\theta) + I'(t)\sin(\theta))}^{Q_\Delta(t)}
            \sin(\omega_c t)

        The delayed components :math:`I_\Delta(t)` and :math:`Q_\Delta(t)` are
        therefore a rotation of the interpolated I and Q channels by
        :math:`\theta`, which is what this function computes.
    """
    assert iq.shape[-1] == 2, (
        "The last dimension of the input tensor should be 2, "
        f"got {iq.shape[-1]} dimensions and shape {iq.shape}."
    )

    cos_theta = ops.cos(theta)
    sin_theta = ops.sin(theta)

    in_phase = iq[..., 0]
    quadrature = iq[..., 1]

    rotated_i = in_phase * cos_theta - quadrature * sin_theta
    rotated_q = quadrature * cos_theta + in_phase * sin_theta

    # Re-assemble the channel axis at the end: (..., 2).
    return ops.stack([rotated_i, rotated_q], axis=-1)
def distance_Rx(grid, probe_geometry):
    """Euclidean distance between every pixel and every transducer element.

    Args:
        grid (Tensor): Pixel positions ``(x, y, z)`` of shape ``(n_pix, 3)``.
        probe_geometry (Tensor): Element positions ``(x, y, z)`` of shape
            ``(n_el, 3)``.

    Returns:
        Tensor: Distances of shape ``(n_pix, n_el)``.
    """
    # Broadcast (n_pix, 1, 3) against (1, n_el, 3) -> (n_pix, n_el, 3).
    offsets = ops.expand_dims(grid, 1) - ops.expand_dims(probe_geometry, 0)
    return ops.linalg.norm(offsets, axis=-1)
def transmit_delays(
    grid,
    t0_delays,
    tx_apodization,
    rx_delays,
    focus_distance,
    polar_angle,
    initial_time,
    azimuth_angle=None,
    transmit_origin=None,
):
    """Compute the transmit delay from transmission to each pixel.

    Uses the **first-arrival** time (minimum over active elements) for pixels
    before the focus (or virtual source) and the **last-arrival** time
    (maximum over active elements) for pixels beyond the focus. Elements with
    zero apodization are excluded from both.

    Args:
        grid (Tensor): Pixel positions ``(x, y, z)`` of shape ``(n_pix, 3)``.
        t0_delays (Tensor): Per-element transmit delays in seconds of shape
            ``(n_el,)``.
        tx_apodization (Tensor): Transmit apodization weights of shape
            ``(n_el,)``. Elements with weight ``0`` are ignored.
        rx_delays (Tensor): Travel times in seconds from elements to pixels of
            shape ``(n_pix, n_el)``.
        focus_distance (float): Focus distance in meters. Use ``0`` or
            ``np.inf`` for plane-wave transmission; negative values place a
            virtual source behind the origin (diverging wave).
        polar_angle (float): Polar steering angle in radians.
        initial_time (float): Time offset for this transmit in seconds; it is
            subtracted from the result.
        azimuth_angle (float, optional): Azimuth steering angle in radians.
            Defaults to ``None`` (treated as 0).
        transmit_origin (Tensor, optional): Origin of the transmit beam of
            shape ``(3,)``. Defaults to ``(0, 0, 0)``.

    Returns:
        Tensor: Transmit delays in seconds of shape ``(n_pix,)``.
    """
    # Add a large offset for elements that are not used in the transmit to
    # disqualify them from being the closest element. The same offset is
    # subtracted in the max branch below, so unused elements become -inf and
    # are also excluded from the last-arrival time.
    offset = ops.where(tx_apodization == 0, np.inf, 0.0)

    # Compute total travel time from t=0 to each pixel via each element
    # rx_delays has shape (n_pix, n_el)
    # t0_delays has shape (n_el,)
    total_times = rx_delays + t0_delays[None, :]

    if azimuth_angle is None:
        azimuth_angle = ops.zeros_like(polar_angle)

    # Set origin to (0, 0, 0) if not provided
    if transmit_origin is None:
        transmit_origin = ops.zeros(3, dtype=grid.dtype)

    # Compute the 3D position of the focal point
    # The beam direction vector (unit vector from polar/azimuth angles)
    beam_direction = ops.stack(
        [
            ops.sin(polar_angle) * ops.cos(azimuth_angle),
            ops.sin(polar_angle) * ops.sin(azimuth_angle),
            ops.cos(polar_angle),
        ]
    )

    # Handle plane wave case where focus_distance is set to zero
    # We use np.inf to consider the first wavefront arrival for all pixels
    focus_distance = ops.where(focus_distance == 0.0, np.inf, focus_distance)

    # Compute focal point position: origin + focus_distance * beam_direction
    # For negative focus_distance (diverging/virtual source), this is behind the origin
    focal_point = transmit_origin + focus_distance * beam_direction  # shape (3,)

    # Deal with plane wave case where focus_distance is infinite and a
    # beam_direction component is zero (np.inf * 0.0 -> nan), so we convert
    # nan to zero. Non-zero components stay +/-inf, which drives the
    # projection below to -inf, so every pixel counts as "before focus".
    focal_point = ops.where(ops.isnan(focal_point), 0.0, focal_point)

    # Compute the position of each pixel relative to the focal point
    pixel_relative_to_focus = grid - focal_point[None, :]  # shape (n_pix, 3)

    # Project onto the beam direction to determine if pixel is before or after focus
    # Positive projection means pixel is in the direction of beam propagation (beyond focus)
    # Negative projection means pixel is behind the focus (before focus)
    projection_along_beam = ops.sum(
        pixel_relative_to_focus * beam_direction[None, :], axis=-1
    )  # shape (n_pix,)

    # For focused waves (positive focus_distance):
    #   - Use min time for pixels before focus (projection < 0)
    #   - Use max time for pixels beyond focus (projection > 0)
    # For diverging waves (negative focus_distance, virtual source):
    #   - The sign of focus_distance flips the logic
    #   - Use min time for pixels between transducer and virtual source
    #   - Use max time for pixels beyond transducer
    # NOTE(review): the sign is cast to a hard-coded "float32" regardless of
    # the grid dtype; confirm this is intended for float64 pipelines.
    is_before_focus = ops.cast(ops.sign(focus_distance), "float32") * projection_along_beam < 0.0

    # Compute the effective time of the pixels to the wavefront by computing the
    # smallest time over all elements (first wavefront arrival) for pixels before
    # the focus, and the largest time (last wavefront contribution) for pixels
    # beyond the focus.
    tx_delay = ops.where(
        is_before_focus,
        ops.min(total_times + offset[None, :], axis=-1),
        ops.max(total_times - offset[None, :], axis=-1),
    )

    # Subtract the initial time offset for this transmit
    tx_delay = tx_delay - initial_time

    return tx_delay
def fnumber_mask(flatgrid, probe_geometry, f_number, fnum_window_fn):
    """Receive-aperture apodization mask based on the f-number.

    Computes a per-pixel, per-element mask that suppresses contributions from
    elements whose angle to a pixel exceeds the acceptance cone defined by the
    f-number. The angle is measured between the element-to-pixel vector and
    the z-axis. The transition within the cone is controlled by
    *fnum_window_fn* (e.g. :func:`fnum_window_fn_rect`,
    :func:`fnum_window_fn_hann`, :func:`fnum_window_fn_tukey`).

    Args:
        flatgrid (Tensor): Flattened pixel grid of shape ``(n_pix, 3)``.
        probe_geometry (Tensor): Element positions of shape ``(n_el, 3)``.
        f_number (float): Receive f-number (depth / aperture). A value of
            ``0`` disables masking (the mask is all ones).
        fnum_window_fn (callable): Window function mapping normalized angles in
            ``[0, 1]`` to weights. Must return ``0`` for inputs ``> 1``.

    Returns:
        Tensor: Mask of shape ``(n_pix, n_el, 1)``.
    """
    # Vector from every element to every pixel: (n_pix, n_el, 3).
    grid_relative_to_probe = flatgrid[:, None] - probe_geometry[None]
    grid_relative_to_probe_norm = ops.linalg.norm(grid_relative_to_probe, axis=-1)
    # Cosine of the angle to the z-axis. The 1e-6 avoids division by zero
    # when a pixel coincides with an element.
    grid_relative_to_probe_z = grid_relative_to_probe[..., 2] / (grid_relative_to_probe_norm + 1e-6)
    alpha = ops.arccos(grid_relative_to_probe_z)

    # The f-number is f_number = z/aperture = 1/(2 * tan(alpha))
    # Rearranging gives us alpha = arctan(1/(2 * f_number))
    # We can use this to compute the maximum angle alpha that is allowed
    max_alpha = ops.arctan(1 / (2 * f_number + keras.backend.epsilon()))
    normalized_angle = alpha / max_alpha
    mask = fnum_window_fn(normalized_angle)

    # Honor the documented contract that f_number == 0 disables masking.
    # Without this, the epsilon above yields a ~90 degree cone that still
    # tapers non-rectangular windows and zeroes pixels behind the probe.
    mask = ops.where(f_number == 0, ops.ones_like(mask), mask)

    # Add dummy channel dimension
    mask = mask[..., None]
    return mask
def calculate_delays_heterogeneous_medium(
    grid,
    sos_map,
    sos_grid_x,
    sos_grid_z,
    t0_delays,
    probe_geometry,
    initial_times,
    sampling_frequency,
    t_peak,
    tx_waveform_indices,
    n_ray_points=100,
):
    """Compute delays from a spatially-varying speed-of-sound map.

    The slowness (1 / speed-of-sound) is averaged over ``n_ray_points``
    samples taken along the straight ray between each element and each
    pixel. That average is multiplied by the ray length to approximate the
    travel time in a heterogeneous medium. For a constant speed of sound use
    :func:`calculate_delays` instead; it is also the recommended choice when
    no SOS map is available.

    .. important::
        Only valid for **multistatic** acquisitions (``n_tx == n_el``).

    .. note::
        Currently only supports 2D grids, not yet compatible with 3D. Only the
        x and z coordinates of ``grid`` and ``probe_geometry`` are used (the
        grid is assumed to lie in the x-z plane). Please use
        :func:`calculate_delays` for 3D data.

    .. note::
        This function is not compatible with the torch backend.

    Args:
        grid (Tensor): Pixel coordinates of shape ``(n_pix, 3)``.
        sos_map (Tensor): Speed-of-sound map of shape ``(Nz, Nx)`` in m/s.
        sos_grid_x (Tensor): x-coordinates of ``sos_map`` columns; the spacing
            is taken from the first two entries.
        sos_grid_z (Tensor): z-coordinates of ``sos_map`` rows; the spacing is
            taken from the first two entries.
        t0_delays (Tensor): Transmit delays of shape ``(n_tx, n_el)``, shifted
            so that the smallest delay is 0.
        probe_geometry (Tensor): Element positions of shape ``(n_el, 3)``.
        initial_times (Tensor): Per-transmit time offsets of shape ``(n_tx,)``.
        sampling_frequency (float): Sampling frequency in Hz.
        t_peak (Tensor): Waveform peak times of shape ``(n_waveforms,)``.
        tx_waveform_indices (Tensor): Index into ``t_peak`` for each transmit
            of shape ``(n_tx,)``.
        n_ray_points (int, optional): Number of integration points along each
            element-to-pixel ray. Higher values improve accuracy at the cost of
            computation time. Defaults to ``100``.

    Returns:
        tuple[Tensor, Tensor]:
            - **tx_delays** — Transmit delays in samples ``(n_tx, n_pix)``.
            - **rx_delays** — Receive delays in samples ``(n_el, n_pix)``.

    Raises:
        NotImplementedError: If the active Keras backend is torch.
        AssertionError: If ``n_tx != n_el``.
    """
    n_tx = ops.shape(t0_delays)[0]
    n_el = ops.shape(probe_geometry)[0]

    if keras.backend.backend() == "torch":
        raise NotImplementedError(
            "calculate_delays_heterogeneous_medium is not currently "
            "implemented for the torch backend."
        )

    assert n_tx == n_el, (
        "Computing delays with heterogeneous medium (a sos grid was provided) "
        "requires a multistatic dataset (n_tx == n_el), "
        f"got n_tx={n_tx}, n_el={n_el}."
    )

    # Fractions along each ray, ascending: 1/n, 2/n, ..., 1
    # (0 would be the element itself, 1 is the pixel).
    ray_parameters = ops.linspace(1, 0, n_ray_points, endpoint=False)[::-1]
    slowness_map = 1 / sos_map

    pix_x = grid[:, 0]
    pix_z = grid[:, 2]
    elem_x = probe_geometry[:, 0]
    elem_z = probe_geometry[:, 2]

    # Origin and step of the SOS grid, used to turn metric positions into
    # fractional (row, column) indices. The step is read from the first two
    # entries only (a uniform spacing is assumed).
    x_origin = sos_grid_x[0]
    z_origin = sos_grid_z[0]
    x_step = sos_grid_x[1] - sos_grid_x[0]
    z_step = sos_grid_z[1] - sos_grid_z[0]

    def _slowness_along_ray(p, ex, ez):
        # Point at fraction p of the way from element (ex, ez) to every pixel.
        ray_x = p * (pix_x - ex) + ex
        ray_z = p * (pix_z - ez) + ez
        col_idx = (ray_x - x_origin) / x_step
        row_idx = (ray_z - z_origin) / z_step
        # Bilinear lookup; points outside the map take the nearest edge value.
        return keras.ops.image.map_coordinates(
            slowness_map,
            ops.stack([row_idx, col_idx], axis=0),
            order=1,
            fill_mode="nearest",
        )

    # Straight-line element-to-pixel distance, shape (n_el, n_pix).
    ray_lengths = ops.sqrt(
        ops.abs(elem_x[:, None] - pix_x[None, :]) ** 2
        + ops.abs(elem_z[:, None] - pix_z[None, :]) ** 2
    )

    # Slowness sampled along every ray: outer map over elements, inner map
    # over ray positions.
    slowness_samples = vmap(
        lambda ex, ez: vmap(lambda p: _slowness_along_ray(p, ex, ez))(ray_parameters)
    )(elem_x, elem_z)

    # Average over the ray positions (axis 1), ignoring NaN samples.
    is_valid = ~ops.isnan(slowness_samples)
    slowness_total = ops.sum(ops.where(is_valid, slowness_samples, 0.0), axis=1)
    n_valid = ops.cast(ops.sum(is_valid, axis=1), slowness_total.dtype)
    mean_slowness = slowness_total / (n_valid + 1e-9)

    tof = mean_slowness * ray_lengths
    rx_delays = tof * sampling_frequency

    # Multistatic (n_tx == n_el): transmit i fires from element i, so the
    # diagonal of t0_delays, (n_tx, n_el) -> (n_tx,), is each transmit's own
    # fire time. Then subtract the initial time, add the fire time and the
    # waveform peak time, and convert to samples.
    tx_seconds = (
        tof
        - initial_times[:, None]
        + ops.diag(t0_delays)[:, None]
        + ops.take(t_peak, tx_waveform_indices)[:, None]
    )
    tx_delays = tx_seconds * sampling_frequency

    return tx_delays, rx_delays