import numpy as np
import scipy.signal
from keras import ops
from zea import log
from zea.func.tensor import (
resample,
)
[docs]
def demodulate_not_jitable(
    rf_data,
    sampling_frequency=None,
    demodulation_frequency=None,
    bandwidth=None,
    filter_coeff=None,
):
    """Demodulates an RF signal to complex base-band (IQ).

    Demodulates the radiofrequency (RF) bandpass signals and returns the
    Inphase/Quadrature (I/Q) components. IQ is a complex array whose real (imaginary)
    part contains the in-phase (quadrature) component.

    This function operates (i.e. demodulates) on the RF signal over the
    (fast-) time axis, which is assumed to be the second to last axis.
    The input is converted to NumPy and processed with SciPy, so this function
    cannot be traced / jitted.

    Args:
        rf_data (ndarray): real valued input array of size [..., n_ax, n_el].
            The second to last axis is the fast-time axis.
        sampling_frequency (float, optional): the sampling frequency of the RF signals
            (in Hz). Required unless filter_coeff is provided.
        demodulation_frequency (float, optional): Demodulation (carrier) frequency (in Hz).
            If None, it is estimated as the centroid of the positive-frequency power
            spectrum of (at most 100 randomly selected) scanlines.
        bandwidth (float, optional): Bandwidth of RF signal in % of the demodulation
            frequency. Defaults to None.
            The bandwidth in % is defined by:
            B = Bandwidth_in_% = Bandwidth_in_Hz*(100/demodulation_frequency).
            The normalized cutoff frequency:
            Wn = Bandwidth_in_Hz/sampling_frequency, i.e:
            Wn = B*(demodulation_frequency/100)/sampling_frequency.
            If None, Wn = min(2*demodulation_frequency/sampling_frequency, 0.5).
            Wn is passed to ``scipy.signal.butter`` as a fraction of the Nyquist
            frequency, i.e. the low-pass cutoff is Bandwidth_in_Hz / 2.
        filter_coeff (list, optional): (b, a), numerator and denominator coefficients
            of a filter that is applied directly to ``rf_data`` with
            ``scipy.signal.lfilter`` along the fast-time axis (result scaled by 2).
            All other parameters are ignored when filter_coeff is provided and no
            down-mixing is performed, so the filter itself must produce the desired
            signal (presumably a complex quadrature band-pass filter).
            See https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.lfilter.html

    Returns:
        iq_data (ndarray): signal with the same shape as rf_data; the complex
            base-band (IQ) signal when filter_coeff is None.

    Raises:
        AssertionError: for non-real RF data, a missing sampling_frequency, or an
            out-of-range bandwidth / normalized cutoff frequency.

    A warning is logged when the RF signal is undersampled in a way that causes
    harmful (overlapping) aliasing.
    """
    rf_data = ops.convert_to_numpy(rf_data)
    assert np.isreal(rf_data).all(), f"RF must contain real RF signals, got {rf_data.dtype}"

    input_shape = rf_data.shape
    n_dim = len(input_shape)
    if n_dim > 2:
        *_, n_ax, n_el = input_shape
    else:
        n_ax, n_el = input_shape

    if filter_coeff is not None:
        b, a = filter_coeff
        # factor 2: to preserve the envelope amplitude
        return scipy.signal.lfilter(b, a, rf_data, axis=-2) * 2

    assert sampling_frequency is not None, "provide sampling_frequency when no filter is given."

    # Time vector of the fast-time samples (starting at t = 0)
    t = np.arange(n_ax) / sampling_frequency

    # Estimate center frequency
    if demodulation_frequency is None:
        # Keep a maximum of 100 randomly selected scanlines
        idx = np.arange(n_el)
        if n_el > 100:
            idx = np.random.permutation(idx)[:100]
        # Power spectrum along fast time, summed over the selected scanlines:
        # shape [..., n_ax]
        P = np.sum(
            np.abs(np.fft.fft(np.take(rf_data, idx, axis=-1), axis=-2)) ** 2,
            axis=-1,
        )
        # Keep the positive frequencies. Slice the LAST (frequency) axis so that
        # batched inputs of shape [..., n_ax, n_el] are handled correctly.
        P = P[..., : n_ax // 2]
        # Carrier frequency: spectral centroid (in bins), converted to Hz
        centroid_bin = np.sum(np.arange(n_ax // 2) * P) / np.sum(P)
        demodulation_frequency = centroid_bin * sampling_frequency / n_ax

    # Normalized cut-off frequency
    if bandwidth is None:
        Wn = min(2 * demodulation_frequency / sampling_frequency, 0.5)
        bandwidth = demodulation_frequency * Wn
    else:
        assert np.isscalar(bandwidth), "The signal bandwidth (in %) must be a scalar."
        assert (bandwidth > 0) & (bandwidth <= 200), (
            "The signal bandwidth (in %) must be within the interval of ]0,200]."
        )
        # bandwidth in Hz
        bandwidth = demodulation_frequency * bandwidth / 100
        Wn = bandwidth / sampling_frequency
    assert (Wn > 0) & (Wn <= 1), (
        "The normalized cutoff frequency is not within the interval of (0,1). "
        "Check the input parameters!"
    )

    # Down-mixing of the RF signals
    carrier = np.exp(-1j * 2 * np.pi * demodulation_frequency * t)
    # add the singleton dimensions so the carrier broadcasts over [..., n_ax, n_el]
    carrier = np.reshape(carrier, (*[1] * (n_dim - 2), n_ax, 1))
    iq_data = rf_data * carrier

    # Low-pass filter (5th order Butterworth, applied forward-backward)
    N = 5
    b, a = scipy.signal.butter(N, Wn, "low")
    # factor 2: to preserve the envelope amplitude
    iq_data = scipy.signal.filtfilt(b, a, iq_data, axis=-2) * 2

    # Display a warning message if harmful aliasing is suspected
    # (the RF signal is undersampled)
    if sampling_frequency < (2 * demodulation_frequency + bandwidth):
        # lower and higher frequencies of the bandpass signal
        fL = demodulation_frequency - bandwidth / 2
        fH = demodulation_frequency + bandwidth / 2
        # Bandpass sampling theorem: aliasing is harmless if, for some integer
        # 1 <= k <= floor(fH / (fH - fL)):  2 * fH / k <= fs <= 2 * fL / (k - 1).
        # k = 1 (fs >= 2 * fH) is excluded by the enclosing condition, which also
        # avoids a division by zero in 2 * fL / (k - 1).
        k = np.arange(2, int(fH // (fH - fL)) + 1)
        harmless_aliasing = np.any(
            (2 * fH / k <= sampling_frequency) & (sampling_frequency <= 2 * fL / (k - 1))
        )
        if not harmless_aliasing:
            log.warning(
                "rf2iq:harmful_aliasing Harmful aliasing is present: the aliases"
                " are not mutually exclusive!"
            )

    return iq_data
[docs]
def upmix(iq_data, sampling_frequency, demodulation_frequency, upsampling_rate=6):
    """Upsample and up-mix complex base-band (IQ) signals to real RF signals.

    The IQ data is resampled along the fast-time axis to ``upsampling_rate`` times as
    many samples (zea ``resample`` with ``order=1``), multiplied by the complex carrier
    exp(2j*pi*demodulation_frequency*t), and the real part, scaled by sqrt(2), is
    returned.

    Args:
        iq_data (ndarray): complex valued input array of size [..., n_ax, n_el].
            The second to last axis is the fast-time axis.
        sampling_frequency (float): the sampling frequency of the input IQ signal (in Hz).
            The resulting sampling frequency of the RF data is upsampling_rate times higher.
        demodulation_frequency (float): modulation (carrier) frequency (in Hz).
        upsampling_rate (int, optional): factor by which the number of fast-time samples
            is increased. Defaults to 6.

    Returns:
        rf_data (ndarray): real valued float32 RF data of size
            [..., n_ax * upsampling_rate, n_el].
    """
    assert iq_data.dtype in [
        "complex64",
        "complex128",
    ], "IQ must contain all complex signals."

    shape = iq_data.shape
    rank = len(shape)
    if rank <= 2:
        n_ax, _ = shape
    else:
        *_, n_ax, _ = shape

    n_samples_up = n_ax * upsampling_rate
    fs_up = sampling_frequency * upsampling_rate
    # Fast-time sample instants of the upsampled signal, starting at t = 0
    time_up = ops.arange(n_samples_up, dtype="float32") / fs_up

    upsampled = resample(iq_data, n_samples=n_samples_up, axis=-2, order=1)

    # Complex carrier, broadcastable against [..., n_samples_up, n_el]
    f_mod = ops.cast(demodulation_frequency, dtype="complex64")
    time_complex = ops.cast(time_up, dtype="complex64")
    carrier = ops.exp(1j * 2 * np.pi * f_mod * time_complex)
    carrier = ops.reshape(carrier, (1,) * (rank - 2) + (n_samples_up, 1))

    real_rf = ops.real(upsampled * carrier) * ops.sqrt(2)
    return ops.cast(real_rf, "float32")
def _sinc(x):
    """Normalized sinc, sin(pi * x) / (pi * x). Equivalent to ``np.sinc(x)``.

    Entries where ``x == 0`` are replaced by 1e-20 before dividing, so the result
    there evaluates to ~1 instead of 0 / 0.
    """
    nonzero_x = ops.where(x == 0, 1.0e-20, x)
    phase = np.pi * nonzero_x
    return ops.sin(phase) / phase
[docs]
def get_band_pass_filter(num_taps, sampling_frequency, f1, f2, validate=True):
    """Design a band-pass FIR filter (windowed sinc with a Hamming window).

    Based on ``scipy.signal.firwin`` with the (symmetric) Hamming window: the ideal
    band-pass impulse response between ``f1`` and ``f2`` is windowed and scaled to
    unit gain at the band center.

    The body uses ``ops`` tensors, so it can be traced (e.g. with ``jax.jit``) when
    ``num_taps`` is static and ``validate=False``. The validation step branches on
    the cutoff values with Python ``if`` statements and therefore needs concrete values.

    Args:
        num_taps (int): number of taps in filter.
        sampling_frequency (float): sample frequency in Hz.
        f1 (float): cutoff frequency in Hz of left band edge.
        f2 (float): cutoff frequency in Hz of right band edge.
        validate (bool, optional): whether to validate the cutoff frequencies. Defaults to True.

    Raises:
        ValueError: if ``validate`` is True and the cutoffs are not within (0, fs/2)
            or are not strictly increasing.

    Returns:
        ndarray: float32 band pass filter taps of shape (num_taps,).
    """
    nyq = 0.5 * ops.cast(sampling_frequency, "float32")
    # Cutoffs normalized to the Nyquist frequency
    low = ops.cast(f1, "float32") / nyq
    high = ops.cast(f2, "float32") / nyq

    if validate:
        if low <= 0 or high >= 1:
            raise ValueError(
                f"Invalid cutoff frequency: frequencies must be greater than 0 and less than fs/2. "
                f"Got f1={low * nyq} Hz, f2={high * nyq} Hz."
            )
        if low >= high:
            raise ValueError(
                f"Invalid cutoff frequencies: the frequencies must be strictly increasing. "
                f"Got f1={low * nyq} Hz, f2={high * nyq} Hz."
            )

    # Tap positions relative to the filter center
    center = 0.5 * (num_taps - 1)
    m = ops.arange(0, num_taps, dtype="float32") - center

    # Ideal band-pass = difference of two ideal low-pass responses, then windowed
    taps = high * _sinc(high * m) - low * _sinc(low * m)
    window = ops.convert_to_tensor(np.hamming(num_taps), dtype=taps.dtype)
    taps = taps * window

    # Scale to unit gain at 0 (lowpass), 1 = Nyquist (highpass), or the band center
    scale_frequency = ops.where(low == 0, 0.0, ops.where(high == 1, 1.0, 0.5 * (low + high)))
    gain = ops.sum(taps * ops.cos(np.pi * m * scale_frequency))
    return taps / gain
[docs]
def get_low_pass_iq_filter(num_taps, sampling_frequency, center_frequency, bandwidth):
    """Design a complex FIR filter by modulating a low-pass prototype.

    A real low-pass FIR filter with cutoff ``bandwidth / 2`` is designed with
    ``scipy.signal.firwin`` and multiplied by the complex exponential
    exp(2j*pi*center_frequency*n/sampling_frequency) for n = 0..num_taps-1, which
    shifts its pass band to ``center_frequency``.

    Args:
        num_taps (int): number of taps in filter.
        sampling_frequency (float): sample frequency in Hz.
        center_frequency (float): center frequency in Hz the pass band is shifted to.
        bandwidth (float): full bandwidth in Hz; the low-pass cutoff is half of it.

    Raises:
        ValueError: if the cutoff frequency (bandwidth / 2) is not within
            (0, sampling_frequency / 2).

    Returns:
        ndarray: complex-valued filter taps of shape (num_taps,).
    """
    half_band = bandwidth / 2
    if not 0 < half_band < sampling_frequency / 2:
        raise ValueError(
            f"Cutoff frequency must be within (0, sampling_frequency / 2), "
            f"got {half_band} Hz, must be within (0, {sampling_frequency / 2}) Hz"
        )

    # Real-valued low-pass prototype
    prototype = scipy.signal.firwin(num_taps, half_band, pass_zero=True, fs=sampling_frequency)

    # Shift the pass band to the center frequency
    sample_times = np.arange(num_taps) / sampling_frequency
    shift = np.exp(1j * 2 * np.pi * center_frequency * sample_times)
    return prototype * shift
[docs]
def complex_to_channels(complex_data, axis=-1):
    """Split complex data into two real-valued channels.

    The real part is placed at index 0 and the imaginary part at index 1 of a new
    axis inserted at position ``axis``.

    Args:
        complex_data (complex ndarray): complex input data.
        axis (int, optional): position of the new channel axis. Defaults to -1.

    Returns:
        ndarray: real array with the real and imaginary components
        stacked over two channels at ``axis``.
    """
    real_part = ops.real(complex_data)
    imag_part = ops.imag(complex_data)
    return ops.stack([real_part, imag_part], axis=axis)
[docs]
def channels_to_complex(data):
    """Combine a trailing [real, imag] channel pair into a complex array.

    Args:
        data (ndarray): input data whose last axis has size 2; index 0 holds the
            real component and index 1 the imaginary component.

    Returns:
        ndarray: complex64 array with the last axis removed.
    """
    assert data.shape[-1] == 2, "Data must have two channels."
    as_complex = ops.cast(data, "complex64")
    real_part = as_complex[..., 0]
    imag_part = as_complex[..., 1]
    return real_part + 1j * imag_part
[docs]
def hilbert(x, N: int = None, axis=-1):
    """Compute the analytical signal of ``x`` in the Fourier domain.

    The spectrum along ``axis`` is multiplied by a one-sided weighting (1 at DC,
    2 for positive frequencies, 1 at Nyquist for even lengths, 0 for negative
    frequencies) and transformed back. This zeros out negative frequencies and
    doubles positive frequencies.

    .. note::
        This is NOT the mathematical Hilbert transform as defined in the
        `Wikipedia article <https://en.wikipedia.org/wiki/Hilbert_transform>`_,
        but instead computes the analytical signal. The implementation reproduces
        the behavior of the :func:`scipy.signal.hilbert` function.

    Both transforms use ``ops.fft`` (which acts on the last axis); the inverse is
    computed with the identity ifft(X) = conj(fft(conj(X))) / N.

    Args:
        x (ndarray): Input data of any shape.
        N (int, optional): Number of points to use for the FFT. If specified and greater
            than the length of the data along the specified axis, the data will be
            zero-padded. If None, uses the length of x along the specified axis.
            Defaults to None.
        axis (int, optional): Axis along which to compute the Hilbert transform.
            Defaults to -1 (last axis).

    Returns:
        ndarray: Complex (complex64) analytical signal with the same shape as the input
        (or padded to length N along ``axis`` if specified). The real part is the
        original signal and the imaginary part is the Hilbert transform of the signal.

    Raises:
        ValueError: If N is specified and is less than the length of x along the
            specified axis.

    Example:
        >>> import numpy as np
        >>> from zea.func import hilbert
        >>> x = np.array([1.0, 2.0, 3.0, 4.0])
        >>> analytical_signal = hilbert(x)
        >>> envelope = np.abs(analytical_signal)
    """
    shape = x.shape
    n_dim = len(shape)
    n_ax = shape[axis]
    if axis < 0:
        axis += n_dim

    if N is None:
        N = n_ax
    else:
        if N < n_ax:
            raise ValueError(f"N must be greater or equal to n_ax, got N={N}, n_ax={n_ax}")
        # Zero-pad the end of `axis` up to length N
        pad_width = [[0, 0] for _ in range(n_dim)]
        pad_width[axis] = [0, np.maximum(N - n_ax, 0)]
        x = ops.pad(x, pad_width, mode="constant", constant_values=0.0)

    # One-sided spectral weights: h[0] = 1, h[1:N/2] = 2, h[N/2] = 1 (even N), rest 0
    k = ops.arange(N, dtype="float32")
    half = N / 2.0
    weights = ops.zeros(N, dtype="float32")
    weights = ops.where(k == 0, 1.0, weights)
    weights = ops.where((k > 0) & (k < half), 2.0, weights)
    weights = ops.where((N % 2 == 0) & (k == half), 1.0, weights)
    weights = ops.cast(weights, "complex64")

    # ops.fft operates on the last axis, so move `axis` there
    to_last = [d for d in range(n_dim) if d != axis] + [axis]
    x = ops.transpose(x, to_last)
    if x.ndim > 1:
        weights = ops.reshape(weights, [1] * (x.ndim - 1) + [-1])

    # Forward FFT of the real signal (zero imaginary part), then apply the weights
    spec_real, spec_imag = ops.fft((x, ops.zeros_like(x)))
    spectrum = ops.cast(spec_real, "complex64") + 1j * ops.cast(spec_imag, "complex64")
    spectrum = spectrum * weights

    # Inverse FFT via ifft(X) = conj(fft(conj(X))) / N
    inv_real, inv_imag = ops.fft((ops.real(spectrum), -ops.imag(spectrum)))
    length = ops.cast(N, "complex64")
    analytic = ops.cast(inv_real, "complex64") / length
    analytic = analytic + 1j * (-ops.cast(inv_imag, "complex64") / length)

    # Move the last axis back to its original position
    from_last = [*range(axis), n_dim - 1, *range(axis, n_dim - 1)]
    return ops.transpose(analytic, from_last)
[docs]
def demodulate(data, demodulation_frequency, sampling_frequency, axis=-3):
    """Demodulate the input data to baseband.

    The analytical signal (negative frequencies removed) is computed with
    :func:`hilbert` along ``axis`` and then multiplied by
    exp(-2j*pi*demodulation_frequency*n/sampling_frequency), where n is the sample
    index along ``axis``. A spectrum that was centered around ``demodulation_frequency``
    is thereby shifted to be centered around 0 Hz. The complex baseband IQ data is
    returned as two real-valued channels (real, imaginary).

    Args:
        data (ops.Tensor): The input data to demodulate of shape `(..., axis, ..., 1)`.
            The trailing singleton channel axis is squeezed.
        demodulation_frequency (float): The center frequency of the signal.
        sampling_frequency (float): The sampling frequency of the signal.
        axis (int, optional): The axis along which to demodulate. Defaults to -3.

    Returns:
        ops.Tensor: The demodulated IQ data of shape `(..., axis, ..., 2)`.
    """
    analytic = hilbert(data, axis=axis)

    # Sample indices n = 0..n_samples-1 along `axis`, broadcastable against `data`
    broadcast_index = [None] * data.ndim
    broadcast_index[axis] = slice(None)
    sample_indices = ops.arange(analytic.shape[axis])[tuple(broadcast_index)]

    f_demod = ops.cast(demodulation_frequency, dtype="complex64")
    fs = ops.cast(sampling_frequency, dtype="complex64")
    n = ops.cast(sample_indices, dtype="complex64")

    # Shift the spectrum down by f_demod
    baseband = analytic * ops.exp(-1j * 2 * np.pi * f_demod * n / fs)

    # Drop the singleton channel axis and split into (real, imag) channels
    return complex_to_channels(ops.squeeze(baseband, axis=-1))
[docs]
def compute_time_to_peak_stack(waveforms, center_frequencies, waveform_sampling_frequency=250e6):
    """Compute the time of the peak of each waveform in a stack of waveforms.

    Each row is processed independently with :func:`compute_time_to_peak`.

    Args:
        waveforms (ndarray): The waveforms of shape (n_waveforms, n_samples).
        center_frequencies (ndarray): The center frequencies of the waveforms in Hz of shape
            (n_waveforms,) or a scalar if all waveforms have the same center frequency.
        waveform_sampling_frequency (float): The sampling frequency of the waveforms in Hz.

    Returns:
        ndarray: The time to peak for each waveform in seconds, shape (n_waveforms,).
    """
    # Broadcast a scalar center frequency to one value per waveform
    per_waveform_frequency = center_frequencies * ops.ones((waveforms.shape[0],))
    peak_times = [
        compute_time_to_peak(waveform, frequency, waveform_sampling_frequency)
        for waveform, frequency in zip(waveforms, per_waveform_frequency)
    ]
    return ops.stack(peak_times)
[docs]
def compute_time_to_peak(waveform, center_frequency, waveform_sampling_frequency=250e6):
    """Compute the time of the peak of the envelope of a waveform.

    The waveform is demodulated with :func:`demodulate` along its sample axis, and the
    index of the maximum of the envelope (magnitude of the IQ signal) is converted to
    seconds, with t = 0 at the first sample.

    Args:
        waveform (ndarray): The waveform of shape (n_samples).
        center_frequency (float): The center frequency of the waveform in Hz. It is only
            used for the down-mixing step; the envelope magnitude is (up to rounding)
            independent of it.
        waveform_sampling_frequency (float): The sampling frequency of the waveform in Hz.

    Returns:
        float: The time to peak for the waveform in seconds.

    Raises:
        ValueError: if the waveform has zero samples.
    """
    n_samples = waveform.shape[0]
    if n_samples == 0:
        raise ValueError("Waveform has zero samples.")

    # Add a trailing singleton channel axis -> shape (n_samples, 1), as expected by
    # `demodulate`. The sample (fast-time) axis is therefore axis -2. Demodulating
    # along axis -1 would apply the Hilbert transform over the length-1 channel axis
    # and return the raw waveform instead of its analytic signal, so the "envelope"
    # would just be |waveform|.
    waveforms_iq_complex_channels = demodulate(
        waveform[..., None], center_frequency, waveform_sampling_frequency, axis=-2
    )
    waveforms_iq_complex = channels_to_complex(waveforms_iq_complex_channels)
    envelope = ops.abs(waveforms_iq_complex)
    peak_idx = ops.argmax(envelope, axis=-1)
    t_peak = ops.cast(peak_idx, dtype="float32") / waveform_sampling_frequency
    return t_peak
[docs]
def envelope_detect(data, axis=-3):
    """Envelope detection of beamformed RF or IQ data.

    If the last axis has size 2, the data is interpreted as IQ data with the real and
    imaginary parts in two channels; these are combined into a complex signal and its
    magnitude is returned.

    Otherwise the data is treated as real RF data with a trailing singleton channel
    axis, which is squeezed. The analytical signal is computed with :func:`hilbert`
    along ``axis`` using an FFT length padded to the next power of two, cropped back to
    the original number of samples, and its magnitude is returned.

    Args:
        data (Tensor): The beamformed data of shape (..., grid_size_z, grid_size_x, n_ch),
            with n_ch == 2 for IQ data or n_ch == 1 for RF data.
        axis (int): Axis along which to apply the Hilbert transform. Defaults to -3.

    Returns:
        Tensor: The envelope detected data of shape (..., grid_size_z, grid_size_x).
    """
    if data.shape[-1] == 2:
        return ops.abs(channels_to_complex(data))

    n_ax = ops.shape(data)[axis]
    # FFT length: next power of 2, M = 2^ceil(log2(n_ax))
    # see https://github.com/tue-bmd/zea/discussions/147
    fft_length = int(2 ** np.ceil(np.log2(n_ax)))
    analytic = hilbert(data, N=fft_length, axis=axis)
    # Crop the zero-padded samples away again
    analytic = ops.take(analytic, ops.arange(n_ax), axis=axis)
    return ops.abs(ops.squeeze(analytic, axis=-1))
[docs]
def log_compress(data, eps=1e-16):
    """Apply logarithmic compression, 20 * log10(data), to data.

    Entries that are exactly zero are replaced by ``eps`` (cast to the dtype of
    ``data``) to avoid log(0). All other values, including negative ones, are passed
    to log10 unchanged.

    Args:
        data (Tensor): input data.
        eps (float, optional): value substituted for zeros. Defaults to 1e-16.

    Returns:
        Tensor: log-compressed data in dB (20 * log10).
    """
    zero_replacement = ops.convert_to_tensor(eps, dtype=data.dtype)
    safe_data = ops.where(data == 0, zero_replacement, data)
    return 20 * ops.log10(safe_data)
[docs]
def make_tgc_curve(n_ax, attenuation_coef, sampling_frequency, center_frequency, sound_speed=1540):
    """Create a Time Gain Compensation (TGC) curve to compensate for depth-dependent attenuation.

    For every axial sample the imaging depth is derived from the round-trip travel time.
    The two-way attenuation in dB is then attenuation_coef * 2 * depth_cm * f_MHz, and
    the returned gain undoes that attenuation.

    Args:
        n_ax (int): Number of samples in the axial direction
        attenuation_coef (float): Attenuation coefficient in dB/cm/MHz.
            For example, typical value for soft tissue is around 0.5 to 0.75 dB/cm/MHz.
        sampling_frequency (float): Sampling frequency in Hz
        center_frequency (float): Center frequency in Hz
        sound_speed (float): Speed of sound in m/s (default: 1540)

    Returns:
        np.ndarray: TGC gain curve of shape (n_ax,) in linear scale (float32)
    """
    # Sample instants in seconds
    sample_times = np.arange(n_ax) / sampling_frequency
    # Round-trip travel time -> depth: divide the travelled path by 2 (meters -> cm)
    depth_cm = 100 * (sample_times * sound_speed / 2)
    frequency_mhz = center_frequency * 1e-6
    # Two-way (transmit + receive) attenuation in dB
    loss_db = 2 * attenuation_coef * depth_cm * frequency_mhz
    # Amplitude gain that compensates the loss
    return (10 ** (loss_db / 20)).astype(np.float32)