Source code for zea.ops.ultrasound

from typing import Tuple

import keras
import numpy as np
from keras import ops

from zea import log
from zea.beamform.beamformer import tof_correction
from zea.display import scan_convert
from zea.func.tensor import (
    apply_along_axis,
    correlate,
    extend_n_dims,
    gaussian_filter,
    reshape_axis,
)
from zea.func.ultrasound import (
    channels_to_complex,
    complex_to_channels,
    demodulate,
    envelope_detect,
    get_band_pass_filter,
    get_low_pass_iq_filter,
    log_compress,
    upmix,
)
from zea.internal.core import (
    DEFAULT_DYNAMIC_RANGE,
    DataTypes,
)
from zea.internal.registry import ops_registry
from zea.ops.base import Filter, Operation
from zea.simulator import simulate_rf
from zea.utils import canonicalize_axis


@ops_registry("simulate_rf")
class Simulate(Operation):
    """Simulate RF data from a set of point scatterers.

    The heavy lifting is delegated to :func:`simulate_rf`; this operation only
    routes parameters and maps over the batch dimension when needed.
    """

    # Define operation-specific static parameters
    STATIC_PARAMS = ["n_ax", "apply_lens_correction"]
    ADD_OUTPUT_KEYS = ["n_ch"]

    def __init__(self, **kwargs):
        super().__init__(output_data_type=DataTypes.RAW_DATA, **kwargs)

    def call(
        self,
        scatterer_positions,
        scatterer_magnitudes,
        probe_geometry,
        apply_lens_correction,
        lens_thickness,
        lens_sound_speed,
        sound_speed,
        n_ax,
        center_frequency,
        sampling_frequency,
        t0_delays,
        initial_times,
        element_width,
        attenuation_coef,
        tx_apodizations,
        **kwargs,
    ):
        """Run the RF simulation.

        ``scatterer_positions`` / ``scatterer_magnitudes`` are simulated
        per batch item via ``ops.map`` when ``with_batch_dim`` is set; all
        other arguments are shared across the batch and forwarded unchanged
        to :func:`simulate_rf`.

        Returns:
            dict: The simulated RF data under ``self.output_key`` and
            ``"n_ch"`` set to 1.
        """
        acquisition = dict(
            probe_geometry=probe_geometry,
            apply_lens_correction=apply_lens_correction,
            lens_thickness=lens_thickness,
            lens_sound_speed=lens_sound_speed,
            sound_speed=sound_speed,
            n_ax=n_ax,
            center_frequency=center_frequency,
            sampling_frequency=sampling_frequency,
            t0_delays=t0_delays,
            initial_times=initial_times,
            element_width=element_width,
            attenuation_coef=attenuation_coef,
            tx_apodizations=tx_apodizations,
        )

        def _simulate_one(positions, magnitudes):
            return simulate_rf(
                scatterer_positions=positions,
                scatterer_magnitudes=magnitudes,
                **acquisition,
            )

        if self.with_batch_dim:
            rf_data = ops.map(
                lambda item: _simulate_one(item["positions"], item["magnitudes"]),
                {
                    "positions": scatterer_positions,
                    "magnitudes": scatterer_magnitudes,
                },
            )
        else:
            rf_data = _simulate_one(scatterer_positions, scatterer_magnitudes)

        return {
            self.output_key: rf_data,
            "n_ch": 1,  # Simulate always returns RF data (so single channel)
        }
@ops_registry("tof_correction")
class TOFCorrection(Operation):
    """Time-of-flight correction operation for ultrasound data.

    Wraps :func:`tof_correction`, mapping it over the leading batch axis when
    ``with_batch_dim`` is set.
    """

    # Define operation-specific static parameters
    STATIC_PARAMS = ["f_number", "apply_lens_correction"]

    def __init__(self, **kwargs):
        super().__init__(
            input_data_type=DataTypes.RAW_DATA,
            output_data_type=DataTypes.ALIGNED_DATA,
            **kwargs,
        )

    def call(
        self,
        flatgrid,
        sound_speed,
        polar_angles,
        focus_distances,
        sampling_frequency,
        f_number,
        demodulation_frequency,
        t0_delays,
        tx_apodizations,
        initial_times,
        probe_geometry,
        t_peak,
        tx_waveform_indices,
        transmit_origins,
        apply_lens_correction=None,
        lens_thickness=None,
        lens_sound_speed=None,
        sos_map=None,
        sos_grid_x=None,
        sos_grid_z=None,
        **kwargs,
    ):
        """Apply time-of-flight correction to the data stored under ``self.key``.

        Args:
            flatgrid (ops.Tensor): Grid points at which to evaluate the time-of-flight.
            sound_speed (float): Sound speed in the medium.
            polar_angles (ops.Tensor): Polar angles for scan lines.
            focus_distances (ops.Tensor): Focus distances for scan lines.
            sampling_frequency (float): Sampling frequency.
            f_number (float): F-number for apodization.
            demodulation_frequency (float): Demodulation frequency.
            t0_delays (ops.Tensor): T0 delays.
            tx_apodizations (ops.Tensor): Transmit apodizations.
            initial_times (ops.Tensor): Initial times.
            probe_geometry (ops.Tensor): Probe element positions.
            t_peak (float): Time to peak of the transmit pulse.
            tx_waveform_indices (ops.Tensor): Index of the transmit waveform for
                each transmit (all zero if there is only one waveform).
            transmit_origins (ops.Tensor): Transmit origins of shape (n_tx, 3).
            apply_lens_correction (bool): Whether to apply lens correction.
            lens_thickness (float): Lens thickness.
            lens_sound_speed (float): Sound speed in the lens.
            sos_map (Tensor): Speed-of-sound map of shape ``(Nz, Nx)`` in m/s.
            sos_grid_x (Tensor): x-coordinates of ``sos_map`` rows.
            sos_grid_z (Tensor): z-coordinates of ``sos_map`` columns.

        Returns:
            dict: The corrected data under ``self.output_key``.
        """
        raw_data = kwargs[self.key]

        shared = dict(
            flatgrid=flatgrid,
            t0_delays=t0_delays,
            tx_apodizations=tx_apodizations,
            sound_speed=sound_speed,
            probe_geometry=probe_geometry,
            initial_times=initial_times,
            sampling_frequency=sampling_frequency,
            demodulation_frequency=demodulation_frequency,
            f_number=f_number,
            polar_angles=polar_angles,
            focus_distances=focus_distances,
            t_peak=t_peak,
            tx_waveform_indices=tx_waveform_indices,
            transmit_origins=transmit_origins,
            apply_lens_correction=apply_lens_correction,
            lens_thickness=lens_thickness,
            lens_sound_speed=lens_sound_speed,
            sos_map=sos_map,
            sos_grid_x=sos_grid_x,
            sos_grid_z=sos_grid_z,
        )

        def _correct(frame):
            return tof_correction(frame, **shared)

        if self.with_batch_dim:
            corrected = ops.map(_correct, raw_data)
        else:
            corrected = _correct(raw_data)
        return {self.output_key: corrected}
@ops_registry("pfield_weighting")
class PfieldWeighting(Operation):
    """Weight aligned data with a pressure-field mask."""

    def __init__(self, **kwargs):
        super().__init__(
            input_data_type=DataTypes.ALIGNED_DATA,
            output_data_type=DataTypes.ALIGNED_DATA,
            **kwargs,
        )

    def call(self, flat_pfield=None, **kwargs):
        """Multiply the data under ``self.key`` with the pressure field.

        Args:
            flat_pfield (ops.Tensor, optional): Pressure field weight mask of
                shape (n_pix, n_tx). When None, the data is passed through.

        Returns:
            dict: Weighted data under ``self.output_key``.
        """
        data = kwargs[self.key]  # must start with ((batch_size,) n_tx, n_pix, ...)
        if flat_pfield is None:
            return {self.output_key: data}

        # (n_pix, n_tx) -> (n_tx, n_pix) to line up with the data layout
        weights = ops.swapaxes(flat_pfield, 0, 1)
        if self.with_batch_dim:
            weights = ops.expand_dims(weights, axis=0)

        # Append trailing singleton axes so the mask broadcasts over the rest.
        missing_dims = ops.ndim(data) - ops.ndim(weights)
        weights = extend_n_dims(weights, axis=-1, n_dims=missing_dims)

        return {self.output_key: data * weights}
@ops_registry("scan_convert")
class ScanConvert(Operation):
    """Scan convert images to cartesian coordinates."""

    STATIC_PARAMS = ["fill_value"]
    ADD_OUTPUT_KEYS = [
        "resolution",
        "x_lim",
        "y_lim",
        "z_lim",
        "rho_range",
        "theta_range",
        "phi_range",
        "d_rho",
        "d_theta",
        "d_phi",
    ]

    def __init__(self, order=1, **kwargs):
        """Initialize the ScanConvert operation.

        Args:
            order (int, optional): Interpolation order. Defaults to 1.
                Currently only GPU support for order=1; higher orders disable jit.
        """
        jittable = order <= 1
        if not jittable:
            log.warning(
                "GPU support for order > 1 is not available. "
                + "Disabling jit for ScanConvert."
            )
        super().__init__(
            input_data_type=DataTypes.IMAGE,
            output_data_type=DataTypes.IMAGE_SC,
            jittable=jittable,
            **kwargs,
        )
        self.order = order

    def call(
        self,
        rho_range=None,
        theta_range=None,
        phi_range=None,
        resolution=None,
        coordinates=None,
        fill_value=None,
        **kwargs,
    ):
        """Scan convert images to cartesian coordinates.

        Args:
            rho_range (Tuple): Range of the rho axis in the polar coordinate system.
                Defined in meters.
            theta_range (Tuple): Range of the theta axis in the polar coordinate system.
                Defined in radians.
            phi_range (Tuple): Range of the phi axis in the polar coordinate system.
                Defined in radians.
            resolution (float): Resolution of the output image in meters per pixel.
                If None, the resolution is computed based on the input data.
            coordinates (Tensor): Coordinates for scan conversion. If None, will be
                computed based on rho_range, theta_range, phi_range and resolution.
                Must be provided when this operation is jit-compiled.
            fill_value (float): Value to fill the image with outside the defined
                region. Defaults to NaN.

        Returns:
            dict: The scan-converted image under ``self.output_key`` plus the
            parameters returned by :func:`scan_convert`.
        """
        if fill_value is None:
            fill_value = np.nan

        data = kwargs[self.key]

        if self._jit_compile and self.jittable:
            # Coordinates cannot be derived inside a jitted graph, so they must be given.
            assert coordinates is not None, (
                "coordinates must be provided to jit scan conversion. "
                "You can set ScanConvert(jit_compile=False) to disable jitting."
            )

        data_out, parameters = scan_convert(
            data,
            rho_range,
            theta_range,
            phi_range,
            resolution,
            coordinates,
            fill_value,
            self.order,
            with_batch_dim=self.with_batch_dim,
        )

        return {self.output_key: data_out, **parameters}
@ops_registry("demodulate")
class Demodulate(Operation):
    """Demodulates the input data to baseband.

    Afterwards the carrier is removed (center frequency 0 Hz) and the IQ data
    is stored as two real-valued channels.
    """

    ADD_OUTPUT_KEYS = ["center_frequency", "n_ch"]

    def __init__(self, axis=-3, **kwargs):
        super().__init__(
            input_data_type=DataTypes.RAW_DATA,
            output_data_type=DataTypes.RAW_DATA,
            jittable=True,
            **kwargs,
        )
        self.axis = axis

    def call(self, demodulation_frequency=None, sampling_frequency=None, **kwargs):
        """Demodulate the data under ``self.key`` along ``self.axis``.

        Returns:
            dict: Two-channel IQ data, ``center_frequency`` of 0.0 and ``n_ch`` of 2.
        """
        iq_two_channel = demodulate(
            data=kwargs[self.key],
            demodulation_frequency=demodulation_frequency,
            sampling_frequency=sampling_frequency,
            axis=self.axis,
        )
        result = {self.output_key: iq_two_channel}
        result["center_frequency"] = 0.0
        result["n_ch"] = 2
        return result
@ops_registry("fir_filter")
class FirFilter(Operation):
    """Apply a FIR filter to the input signal using convolution.

    The filter taps are read from the input dictionary under ``filter_key``.
    """

    def __init__(
        self,
        axis: int,
        complex_channels: bool = False,
        filter_key: str = "fir_filter_taps",
        **kwargs,
    ):
        """
        Args:
            axis (int): Axis along which to apply the filter. Must not be the batch
                dimension, nor the complex channel axis when ``complex_channels=True``.
            complex_channels (bool): Whether the last dimension holds the real and
                imaginary parts. When True, the signal is converted to a complex dtype
                before filtering and split back into two channels afterwards.
            filter_key (str): Key in the input dictionary where the FIR filter taps
                are stored. Default is "fir_filter_taps".
        """
        super().__init__(**kwargs)
        self._check_axis(axis)
        self.axis = axis
        self.complex_channels = complex_channels
        self.filter_key = filter_key

    def _check_axis(self, axis, ndim=None):
        """Raise ValueError if ``axis`` refers to the batch dimension."""
        if not self.with_batch_dim:
            return
        hits_batch = axis == 0 or (ndim is not None and axis == -ndim)
        if hits_batch:
            raise ValueError("Cannot apply FIR filter along batch dimension.")

    @property
    def valid_keys(self):
        """Get the valid keys for the `call` method."""
        return self._valid_keys.union({self.filter_key})

    def call(self, **kwargs):
        """Filter the data under ``self.key`` with the taps under ``self.filter_key``."""
        signal = kwargs[self.key]
        taps = kwargs[self.filter_key]

        rank = ops.ndim(signal)
        self._check_axis(self.axis, rank)
        axis = canonicalize_axis(self.axis, rank)

        if self.complex_channels:
            assert axis < rank - 1, (
                "When using complex_channels=True, the complex channels are removed to convert"
                " to complex numbers before filtering, so axis cannot be the last axis."
            )
            signal = channels_to_complex(signal)

        def _convolve(trace):
            # Correlating with time-reversed taps is a convolution.
            return correlate(trace, ops.flip(taps, axis=0), mode="same")

        filtered = apply_along_axis(_convolve, axis, signal)

        if self.complex_channels:
            filtered = complex_to_channels(filtered)
        return {self.output_key: filtered}
@ops_registry("low_pass_filter")
class LowPassFilterIQ(FirFilter):
    """Apply a low-pass FIR filter to the demodulated IQ (n_ch=2) input signal using convolution.

    It is recommended to use :class:`FirFilter` with pre-computed filter taps for jittable
    operations. The :class:`LowPassFilterIQ` operation itself is not jittable and is provided
    for convenience only.

    Uses :func:`get_low_pass_iq_filter` to compute the filter taps.
    """

    def __init__(
        self, axis: int = -3, num_taps: int = 127, filter_key: str = "low_pass_filter", **kwargs
    ):
        """Initialize the LowPassFilterIQ operation.

        Args:
            axis (int): Axis along which to apply the filter. Cannot be the batch dimension and
                cannot be the complex channel axis (the last axis). Default is -3, which is the
                ``n_ax`` axis for standard ultrasound data layout.
            num_taps (int): Number of taps in the FIR filter. Default is 127. Odd will result
                in a type I filter, even in a type II filter.
            filter_key (str): Key under which the computed taps are stored before
                delegating to :class:`FirFilter`. Default is "low_pass_filter".

        Raises:
            ValueError: If ``jittable=True`` or ``complex_channels=False`` is passed.
        """
        # Passing jittable=False explicitly is allowed (and matches the error
        # message); only a truthy value is rejected.
        if kwargs.pop("jittable", False):
            raise ValueError("LowPassFilterIQ is not jittable, so jittable must be set to False.")
        if not kwargs.pop("complex_channels", True):
            raise ValueError(
                "LowPassFilterIQ operates on IQ data, so complex_channels must be True."
            )
        super().__init__(
            axis=axis,
            complex_channels=True,
            filter_key=filter_key,
            jittable=False,
            **kwargs,
        )
        self.num_taps = num_taps

    def call(self, bandwidth, sampling_frequency, center_frequency, **kwargs):
        """Compute low-pass taps for the given parameters and filter the data.

        The parameters are converted to Python scalars via ``.item()``, which is
        why this operation cannot be jitted.
        """
        lpf = get_low_pass_iq_filter(
            self.num_taps,
            ops.convert_to_numpy(sampling_frequency).item(),
            ops.convert_to_numpy(center_frequency).item(),
            ops.convert_to_numpy(bandwidth).item(),
        )
        kwargs[self.filter_key] = lpf
        return super().call(**kwargs)
@ops_registry("band_pass_filter")
class BandPassFilter(FirFilter):
    """Apply a band-pass FIR filter to the real input signal using convolution.

    The passband is centered on ``demodulation_frequency`` and spans
    ``demodulation_frequency ± bandwidth/2``, so apply this before demodulating
    to baseband.

    Filter taps are recomputed on every call; for fixed taps use
    :class:`FirFilter` directly.
    """

    def __init__(
        self, axis: int = -3, num_taps: int = 127, filter_key: str = "band_pass_filter", **kwargs
    ):
        """Initialize the BandPassFilter operation.

        Args:
            axis (int): Axis along which to apply the filter. Cannot be the batch dimension.
                Default is -3, which is the ``n_ax`` axis for standard ultrasound data layout.
            num_taps (int): Number of taps in the FIR filter. Default is 127. Odd will result
                in a type I filter, even in a type II filter.
            filter_key (str): Key under which the computed taps are stored.
        """
        if kwargs.pop("complex_channels", False):
            raise ValueError(
                "BandPassFilter operates on a real signal, so complex_channels must be False."
            )
        super().__init__(
            axis=axis,
            complex_channels=False,
            filter_key=filter_key,
            **kwargs,
        )
        self.num_taps = num_taps

    def call(self, sampling_frequency, demodulation_frequency, bandwidth, **kwargs):
        """Apply band-pass filter with specified bandwidth.

        Args:
            sampling_frequency (float): Sampling frequency in Hz.
            demodulation_frequency (float): Center frequency in Hz.
            bandwidth (float): Total passband width in Hz.

        Returns:
            dict: Dictionary containing filtered signal.
        """
        half_bandwidth = bandwidth / 2
        low_edge = demodulation_frequency - half_bandwidth
        high_edge = demodulation_frequency + half_bandwidth
        kwargs[self.filter_key] = get_band_pass_filter(
            self.num_taps,
            sampling_frequency,
            low_edge,
            high_edge,
            validate=not self._jit_compile,
        )
        return super().call(**kwargs)
@ops_registry("channels_to_complex")
class ChannelsToComplex(Operation):
    """Convert channel-stacked data to complex values via :func:`channels_to_complex`."""

    def call(self, **kwargs):
        """Return ``channels_to_complex`` applied to the data under ``self.key``."""
        return {self.output_key: channels_to_complex(kwargs[self.key])}
@ops_registry("complex_to_channels")
class ComplexToChannels(Operation):
    """Split complex data into real-valued channels via :func:`complex_to_channels`."""

    def __init__(self, axis=-1, **kwargs):
        """
        Args:
            axis (int): Axis forwarded to :func:`complex_to_channels`. Default is -1.
        """
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, **kwargs):
        """Return ``complex_to_channels`` applied to the data under ``self.key``."""
        converted = complex_to_channels(kwargs[self.key], axis=self.axis)
        return {self.output_key: converted}
@ops_registry("lee_filter")
class LeeFilter(Filter):
    """
    The Lee filter is a speckle reduction filter commonly used in synthetic aperture radar (SAR)
    and ultrasound image processing. It smooths the image while preserving edges and details.
    This implementation uses Gaussian filter for local statistics and treats channels independently.

    Lee, J.S. (1980). Digital image enhancement and noise filtering by use of local statistics.
    IEEE Transactions on Pattern Analysis and Machine Intelligence, (2), 165-168.
    """

    def __init__(
        self,
        sigma: float,
        mode: str = "symmetric",
        cval: float | None = None,
        truncate: float = 4.0,
        axes: Tuple[int] = (-3, -2),
        **kwargs,
    ):
        """
        Args:
            sigma (float or tuple): Standard deviation for Gaussian kernel. The standard
                deviations of the Gaussian filter are given for each axis as a sequence,
                or as a single number, in which case it is equal for all axes.
            mode (str, optional): Padding mode for the input image. Default is 'symmetric'.
                See [keras docs](https://www.tensorflow.org/api_docs/python/tf/keras/ops/pad) for
                all options and [tensorflow docs](https://www.tensorflow.org/api_docs/python/tf/pad)
                for some examples. Note that the naming differs from scipy.ndimage.gaussian_filter!
            cval (float, optional): Value to fill past edges of input if mode is 'constant'.
                Default is None.
            truncate (float, optional): Truncate the filter at this many standard deviations.
                Default is 4.0.
            axes (Tuple[int], optional): If None, input is filtered along all axes. Otherwise,
                input is filtered along the specified axes. When axes is specified, any per-axis
                tuple passed as ``sigma`` must match the length of axes; the ith entry
                corresponds to the ith entry in axes. Default is (-3, -2), which corresponds
                to the height and width dimensions of a (..., height, width, channels) tensor.
        """
        super().__init__(**kwargs)
        self.sigma = sigma
        self.mode = mode
        self.cval = cval
        self.truncate = truncate
        self.axes = axes

    def call(self, **kwargs):
        """Apply the Lee filter to the input data.

        Args:
            data (ops.Tensor): Input image data of shape (height, width, channels) with
                optional batch dimension if ``self.with_batch_dim``.

        Returns:
            dict: Filtered image under ``self.output_key``.
        """
        data = kwargs.pop(self.key)

        axes = self._resolve_filter_axes(data, self.axes)

        def _smooth(x):
            return gaussian_filter(
                x, self.sigma, mode=self.mode, cval=self.cval, truncate=self.truncate, axes=axes
            )

        # Local first and second moments from Gaussian-weighted neighbourhoods
        img_mean = _smooth(data)
        img_sqr_mean = _smooth(data**2)

        # E[x^2] - E[x]^2 can dip slightly below zero from floating-point
        # cancellation; a negative variance would push the weights outside [0, 1].
        img_variance = ops.maximum(img_sqr_mean - img_mean**2, 0.0)

        # Calculate global variance (per channel)
        overall_variance = ops.var(data, axis=axes, keepdims=True)

        # Adaptive weights: ~1 near edges (high local variance), ~0 in flat regions
        eps = keras.config.epsilon()
        img_weights = img_variance / (img_variance + overall_variance + eps)

        # Apply Lee filter formula
        img_output = img_mean + img_weights * (data - img_mean)

        return {self.output_key: img_output}
@ops_registry("companding")
class Companding(Operation):
    """Companding according to the A- or μ-law algorithm.

    Invertible compressing operation. Used to compress
    dynamic range of input data (and subsequently expand).

    μ-law companding:
    https://en.wikipedia.org/wiki/%CE%9C-law_algorithm
    A-law companding:
    https://en.wikipedia.org/wiki/A-law_algorithm

    Args:
        expand (bool, optional): If set to False (default),
            data is compressed, else expanded.
        comp_type (str): either `a` or `mu`.

    The compression parameters are passed to :meth:`call`:
    ``mu`` (defaults to 255) and ``A`` (defaults to 87.6).
    """

    def __init__(self, expand=False, comp_type="mu", **kwargs):
        super().__init__(**kwargs)
        self.expand = expand
        self.comp_type = comp_type.lower()
        if self.comp_type not in ["mu", "a"]:
            raise ValueError("comp_type must be 'mu' or 'a'.")

        if self.comp_type == "mu":
            self._compand_func = self._mu_law_expand if self.expand else self._mu_law_compress
        else:
            self._compand_func = self._a_law_expand if self.expand else self._a_law_compress

    @staticmethod
    def _mu_law_compress(x, mu=255, **kwargs):
        x = ops.clip(x, -1, 1)
        return ops.sign(x) * ops.log(1.0 + mu * ops.abs(x)) / ops.log(1.0 + mu)

    @staticmethod
    def _mu_law_expand(y, mu=255, **kwargs):
        y = ops.clip(y, -1, 1)
        return ops.sign(y) * ((1.0 + mu) ** ops.abs(y) - 1.0) / mu

    @staticmethod
    def _a_law_compress(x, A=87.6, **kwargs):
        x = ops.clip(x, -1, 1)
        x_sign = ops.sign(x)
        x_abs = ops.abs(x)
        one_plus_log_a = 1.0 + ops.log(A)
        linear = x_sign * A * x_abs / one_plus_log_a
        # ops.where evaluates both branches; without this guard log(A * 0) = -inf
        # in the unused branch turns gradients at x == 0 into NaN. Values with
        # x_abs >= 1/A (the only ones that select this branch) are unchanged.
        safe_abs = ops.maximum(x_abs, 1.0 / A)
        logarithmic = x_sign * (1.0 + ops.log(A * safe_abs)) / one_plus_log_a
        return ops.where(x_abs < (1.0 / A), linear, logarithmic)

    @staticmethod
    def _a_law_expand(y, A=87.6, **kwargs):
        y = ops.clip(y, -1, 1)
        y_sign = ops.sign(y)
        y_abs = ops.abs(y)
        one_plus_log_a = 1.0 + ops.log(A)
        linear = y_sign * y_abs * one_plus_log_a / A
        exponential = y_sign * ops.exp(y_abs * one_plus_log_a - 1.0) / A
        return ops.where(y_abs < (1.0 / one_plus_log_a), linear, exponential)

    def call(self, mu=255, A=87.6, **kwargs):
        """Compress or expand the data under ``self.key``.

        Args:
            mu (float): μ-law parameter. Defaults to 255.
            A (float): A-law parameter. Defaults to 87.6.
        """
        data = kwargs[self.key]

        mu = ops.cast(mu, data.dtype)
        A = ops.cast(A, data.dtype)

        data_out = self._compand_func(data, mu=mu, A=A)
        return {self.output_key: data_out}
@ops_registry("downsample")
class Downsample(Operation):
    """Downsample data along a specific axis by keeping every ``factor``-th sample."""

    ADD_OUTPUT_KEYS = ["sampling_frequency", "n_ax"]

    def __init__(self, factor: int = 1, phase: int = 0, axis: int = -3, **kwargs):
        """
        Args:
            factor (int): Keep one out of every ``factor`` samples. Must be >= 1.
            phase (int): Index of the first kept sample; must satisfy
                ``0 <= phase < factor``.
            axis (int): Axis to downsample. Default is -3.
        """
        super().__init__(
            **kwargs,
        )
        if factor < 1:
            raise ValueError("Downsample factor must be >= 1.")
        if phase < 0 or phase >= factor:
            raise ValueError("phase must satisfy 0 <= phase < factor.")
        self.factor = factor
        self.phase = phase
        self.axis = axis

    def call(self, sampling_frequency=None, n_ax=None, **kwargs):
        """Downsample the data and update ``sampling_frequency`` / ``n_ax`` if given."""
        data = kwargs[self.key]
        length = ops.shape(data)[self.axis]
        sample_idx = ops.arange(self.phase, length, self.factor)
        data_downsampled = ops.take(data, sample_idx, axis=self.axis)

        output = {self.output_key: data_downsampled}

        # downsampling also affects the sampling frequency
        if sampling_frequency is not None:
            output["sampling_frequency"] = sampling_frequency / self.factor
        if n_ax is not None:
            # arange(phase, n_ax, factor) keeps ceil((n_ax - phase) / factor)
            # samples; plain n_ax // factor undercounts when n_ax is not a
            # multiple of factor. phase < factor keeps the numerator >= 0.
            output["n_ax"] = (n_ax - self.phase + self.factor - 1) // self.factor
        return output
@ops_registry("anisotropic_diffusion")
class AnisotropicDiffusion(Operation):
    """Speckle Reducing Anisotropic Diffusion (SRAD) filter.

    Reference:
    - https://www.researchgate.net/publication/5602035_Speckle_reducing_anisotropic_diffusion
    - https://nl.mathworks.com/matlabcentral/fileexchange/54044-image-despeckle-filtering-toolbox
    """

    def call(self, niter=100, lmbda=0.1, rect=None, eps=1e-6, **kwargs):
        """Anisotropic diffusion filter.

        Assumes input data is non-negative.

        Args:
            niter: Number of iterations.
            lmbda: Lambda parameter.
            rect: Rectangle [x1, y1, x2, y2] of a homogeneous region used to estimate
                the speckle scale q0 (optional). If None, q0 is estimated from the
                whole image.
            eps: Small epsilon for stability.

        Returns:
            Filtered image (2D tensor or batch of images).
        """
        data = kwargs[self.key]

        if not self.with_batch_dim:
            data = ops.expand_dims(data, axis=0)

        batch_size = ops.shape(data)[0]

        results = []
        for i in range(batch_size):
            image = data[i]
            image_out = self._anisotropic_diffusion_single(image, niter, lmbda, rect, eps)
            results.append(image_out)

        result = ops.stack(results, axis=0)

        if not self.with_batch_dim:
            result = ops.squeeze(result, axis=0)

        return {self.output_key: result}

    def _anisotropic_diffusion_single(self, image, niter, lmbda, rect, eps):
        """Apply anisotropic diffusion to a single image (2D).

        The image is exponentiated before diffusion and log-transformed afterwards.
        """
        image = ops.exp(image)
        M, N = image.shape
        # Zero borders used for the shifted neighbour images (loop-invariant).
        zero_row = ops.zeros((1, N), dtype=image.dtype)
        zero_col = ops.zeros((M, 1), dtype=image.dtype)

        for _ in range(niter):
            iN = ops.concatenate([image[1:], zero_row], axis=0)
            iS = ops.concatenate([zero_row, image[:-1]], axis=0)
            jW = ops.concatenate([image[:, 1:], zero_col], axis=1)
            jE = ops.concatenate([zero_col, image[:, :-1]], axis=1)

            # Speckle scale q0 from the homogeneous region (or whole image).
            # Previously, with rect=None q0 was never computed and the diffusion
            # coefficient silently reused the Laplacian denominator below.
            if rect is not None:
                x1, y1, x2, y2 = rect
                uniform_region = image[x1:x2, y1:y2]
            else:
                uniform_region = image
            q0_squared = (ops.std(uniform_region) / (ops.mean(uniform_region) + eps)) ** 2

            dN = iN - image
            dS = iS - image
            dW = jW - image
            dE = jE - image

            G2 = (dN**2 + dS**2 + dW**2 + dE**2) / (image**2 + eps)
            L = (dN + dS + dW + dE) / (image + eps)
            num = (0.5 * G2) - ((1 / 16) * (L**2))
            den = (1 + ((1 / 4) * L)) ** 2
            q_squared = num / (den + eps)

            # SRAD diffusion coefficient c(q) = 1 / (1 + (q^2 - q0^2) / (q0^2 (1 + q0^2)))
            relative_q = (q_squared - q0_squared) / (q0_squared * (1 + q0_squared) + eps)
            c = 1.0 / (1 + relative_q)

            cS = ops.concatenate([zero_row, c[:-1]], axis=0)
            cE = ops.concatenate([zero_col, c[:, :-1]], axis=1)

            D = (cS * dS) + (c * dN) + (cE * dE) + (c * dW)

            image = image + (lmbda / 4) * D

        result = ops.log(image)
        return result
@ops_registry("envelope_detect")
class EnvelopeDetect(Operation):
    """Envelope detection of RF signals via :func:`envelope_detect`."""

    def __init__(self, axis=-3, **kwargs):
        """
        Args:
            axis (int): Axis forwarded to :func:`envelope_detect`. Default is -3.
        """
        super().__init__(
            input_data_type=DataTypes.BEAMFORMED_DATA,
            output_data_type=DataTypes.ENVELOPE_DATA,
            **kwargs,
        )
        self.axis = axis

    def call(self, **kwargs):
        """Detect the envelope of the data under ``self.key``.

        The input is expected to be beamformed data of shape
        (..., grid_size_z, grid_size_x, n_ch); the output is the envelope,
        expected shape (..., grid_size_z, grid_size_x).
        """
        envelope = envelope_detect(kwargs[self.key], axis=self.axis)
        return {self.output_key: envelope}
@ops_registry("upmix")
class UpMix(Operation):
    """Upmix IQ data to RF data."""

    def __init__(self, upsampling_rate=1, **kwargs):
        """
        Args:
            upsampling_rate (int): Rate forwarded to :func:`upmix`. Default is 1.
        """
        super().__init__(**kwargs)
        self.upsampling_rate = upsampling_rate

    def call(self, sampling_frequency=None, demodulation_frequency=None, **kwargs):
        """Upmix the data under ``self.key``.

        Single-channel input is returned unchanged (with a warning); two-channel
        input is first converted to complex. The result gets a trailing
        singleton channel axis.
        """
        data = kwargs[self.key]
        n_channels = data.shape[-1]

        if n_channels == 1:
            log.warning("Upmixing is not applicable to RF data.")
            return {self.output_key: data}
        if n_channels == 2:
            data = channels_to_complex(data)

        rf = upmix(data, sampling_frequency, demodulation_frequency, self.upsampling_rate)
        return {self.output_key: ops.expand_dims(rf, axis=-1)}
@ops_registry("log_compress")
class LogCompress(Operation):
    """Logarithmic compression of data."""

    def __init__(self, clip: bool = True, **kwargs):
        """Initialize the LogCompress operation.

        Args:
            clip (bool): Whether to clip the output to the dynamic range. Defaults to True.
        """
        super().__init__(
            input_data_type=DataTypes.ENVELOPE_DATA,
            output_data_type=DataTypes.IMAGE,
            **kwargs,
        )
        self.clip = clip

    def call(self, dynamic_range=None, **kwargs):
        """Apply logarithmic compression to data.

        Args:
            dynamic_range (tuple, optional): (low, high) clipping bounds in dB.
                Defaults to ``DEFAULT_DYNAMIC_RANGE``.

        Returns:
            dict: Dictionary containing log-compressed data
        """
        data = kwargs[self.key]

        bounds = ops.array(DEFAULT_DYNAMIC_RANGE) if dynamic_range is None else dynamic_range
        bounds = ops.cast(bounds, data.dtype)

        compressed = log_compress(data)
        if not self.clip:
            return {self.output_key: compressed}
        return {self.output_key: ops.clip(compressed, bounds[0], bounds[1])}
@ops_registry("reshape_grid")
class ReshapeGrid(Operation):
    """Reshape flat grid data to grid shape."""

    def __init__(self, axis=0, **kwargs):
        """
        Args:
            axis (int): Axis holding the flattened pixels, excluding any batch
                dimension. Default is 0.
        """
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, grid, **kwargs):
        """
        Args:
            - data (Tensor): The flat grid data of shape (..., n_pix, ...).

        Returns:
            - reshaped_data (Tensor): The reshaped data of shape (..., grid.shape, ...).
        """
        data = kwargs[self.key]
        axis = self.axis
        # A leading batch dimension only shifts axes counted from the front;
        # negative axes count from the end and must not be offset.
        if self.with_batch_dim and axis >= 0:
            axis += 1
        reshaped_data = reshape_axis(data, grid.shape[:-1], axis)
        return {self.output_key: reshaped_data}
@ops_registry("apply_window")
class ApplyWindow(Operation):
    """Apply a window function to the input data along a specific axis.

    This operation can be used to zero out the end and/or beginning of the signal
    and apply a window of some size to transition from the zeroed region to the
    unmodified region. The axis is divided into five regions:

    [start (zero)] - [size (window)] - [middle (unmodified)] - [size (window)] - [end (zero)]
    """

    STATIC_PARAMS = ["axis", "size", "window_type", "start", "end"]

    def __init__(self, axis=-3, size=32, start=16, end=0, window_type="hanning", **kwargs):
        """
        Args:
            axis (int): Axis along which to apply the window.
            size (int): Size of the window to apply at the start and end regions.
            start (int): Number of elements to zero at the beginning.
            end (int): Number of elements to zero at the end.
            window_type (str): Type of window to apply. Supported types are "hanning" and
                "linear".
        """
        super().__init__(**kwargs)
        self.axis = axis
        self.size = int(size)
        self.start = int(start)
        self.end = int(end)
        self._check_inputs()
        self.window_type = window_type
        # Use the validated integer size (a float ``size`` would otherwise reach the window op).
        self.window = self._get_window(self.window_type, self.size, "float32")

    def _check_inputs(self):
        if self.start < 0:
            raise ValueError("start must be >= 0.")
        if self.end < 0:
            raise ValueError("end must be >= 0.")
        if self.size < 0:
            raise ValueError("size must be >= 0.")

    @staticmethod
    def _get_window(window_type, size, dtype):
        """Return a rising-then-falling window of length ``2 * size``."""
        if window_type == "hanning":
            window = ops.hanning(size * 2)
        elif window_type == "linear":
            window = ops.concatenate(
                [ops.linspace(0.0, 1.0, size), ops.linspace(1.0, 0.0, size)], axis=0
            )
        else:
            raise ValueError(f"Unsupported window type: {window_type}")
        return ops.cast(window, dtype)

    def call(self, **kwargs):
        """Multiply the data under ``self.key`` by the window mask along ``self.axis``."""
        data = kwargs[self.key]
        dtype = ops.dtype(data)
        axis = canonicalize_axis(self.axis, ops.ndim(data))

        length = ops.shape(data)[axis]
        # Explicit middle length: slicing ``ones[a:-(end + size)]`` yields an empty
        # slice when end + size == 0, which produced a mask shorter than the axis.
        n_middle = length - self.start - 2 * self.size - self.end
        if n_middle < 0:
            raise ValueError("start, size, and end are larger than the axis length.")

        window = ops.cast(self.window, dtype)
        mask = ops.concatenate(
            [
                ops.zeros((self.start,), dtype=dtype),
                window[: self.size],
                ops.ones((n_middle,), dtype=dtype),
                window[self.size :],
                ops.zeros((self.end,), dtype=dtype),
            ],
            axis=0,
        )

        # Broadcast the 1-D mask along the chosen axis.
        shape = [1] * ops.ndim(data)
        shape[axis] = length
        mask = ops.reshape(mask, shape)

        return {self.output_key: data * mask}
@ops_registry("common_midpoint_phase_error")
class CommonMidpointPhaseError(Operation):
    """Calculates the Common Midpoint Phase Error (CMPE)

    Computes CMPE between translated transmit and receive apertures with a common midpoint.

    .. important::
        Only works for multistatic datasets, e.g. synthetic aperture data.

    .. note::
        This was directly adapted from the Differentiable Beamforming for Ultrasound Autofocusing
        (DBUA) paper, see `original paper and code <https://waltersimson.com/dbua/>`_.
    """  # noqa: E501

    def __init__(
        self,
        reshape_grid=True,
        **kwargs,
    ):
        """
        Args:
            reshape_grid (bool): Stored on the instance; not used by the phase-map
                computation in this class.
        """
        # Previously spelled ``_init_``, so this constructor never ran. setdefault
        # keeps callers that pass their own data types working.
        # input_data_type=None: conceptually an image (of the phase map).
        kwargs.setdefault("input_data_type", None)
        kwargs.setdefault("output_data_type", DataTypes.IMAGE)
        super().__init__(**kwargs)
        self.reshape_grid = reshape_grid

    def create_subapertures(self, data, halfsa, dx):
        """Create subapertures from the data.

        Args:
            data (ops.Tensor): The data to create subapertures from.
            halfsa (int): Half of the subaperture.
            dx (float): The spacing between the subapertures.

        Returns:
            transmit_subap (ops.Tensor): The transmit subapertures.
            receive_subap (ops.Tensor): The receive subapertures.
        """
        n_tx, _, n_rx, _ = data.shape
        # Banded matrix: each row sums 2*halfsa+1 neighbouring elements.
        # NOTE(review): the (n_rx, n_rx) diagonals only fit (n_rx, n_tx) when
        # n_tx == n_rx, i.e. the multistatic case this class requires.
        receive_subaps = ops.zeros((n_rx, n_tx))
        for diag in range(-halfsa, halfsa + 1):
            receive_subaps = receive_subaps + ops.diag(ops.ones((n_rx - abs(diag),)), diag)

        receive_subaps = receive_subaps[halfsa : receive_subaps.shape[0] - halfsa : dx]
        transmit_subaps = ops.flip(receive_subaps, axis=0)

        return transmit_subaps, receive_subaps

    def process_phase_map(self, data, **kwargs):
        """Create the common midpoint subaperture phase error map.

        Args:
            data (ops.Tensor): The data to create the phase error map from,
                laid out as [n_tx, n_pix, n_rx, n_ch] with n_ch == 2.

        Returns:
            phase_error_map (ops.Tensor): The phase error map.
        """
        transmit_subaps, receive_subaps = self.create_subapertures(data, 8, 1)

        # Same helper the rest of this module uses for 2-channel -> complex.
        complex_data = channels_to_complex(data)
        # [n_tx, n_pix, n_rx, n_ch] -> [n_tx, n_pix, n_rx] (complex)
        complex_data = ops.transpose(complex_data, (2, 0, 1))  # [n_rx, n_tx, n_pix]

        # NOTE(review): matmul of a 2-D matrix with a 3-D tensor contracts over the
        # tensor's middle axis; this relies on n_tx == n_rx (multistatic data).
        rx_zero_count = ops.matmul(receive_subaps, ops.cast(complex_data == 0, "int32"))
        # Mask out subapertures with point outside fov in receive
        rx_valid = rx_zero_count <= 1
        complex_data_rx = ops.matmul(receive_subaps, complex_data)
        complex_data_rx = ops.where(rx_valid, complex_data_rx, 0)

        complex_data_rx = ops.transpose(complex_data_rx, (1, 0, 2))  # [n_tx, n_subap_rx, n_pix]
        tx_zero_count = ops.matmul(transmit_subaps, ops.cast(complex_data_rx == 0, "int32"))
        # Mask out subapertures with point outside fov in transmit
        tx_valid = tx_zero_count <= 1
        data = ops.matmul(transmit_subaps, complex_data_rx)
        data = ops.where(tx_valid, data, 0)
        data = ops.transpose(data, (1, 0, 2))  # [n_subap_tx, n_subap, n_pix]

        # take diagonals
        a = data[:-1, :-1]
        b = data[1:, 1:]
        valid = (a != 0) & (b != 0)

        # compute phase difference between cmp neighbours
        # This only works if the array is regularly spaced
        xy = a * ops.conj(b)
        xy = ops.where(valid, xy, 0)
        dphi = ops.angle(xy)
        dphi = ops.abs(dphi)
        dphi = ops.sum(dphi, (0, 1)) / ops.cast(ops.sum(valid, (0, 1)), dphi.dtype)

        return dphi

    def call(
        self,
        **kwargs,
    ):
        """Compute the phase error map, mapping over the batch when present."""
        data = kwargs[self.key]

        if not self.with_batch_dim:
            pemap = self.process_phase_map(data)
        else:
            pemap = ops.map(self.process_phase_map, data)

        return {self.output_key: pemap}