zea.func
Functional API of zea.
This module provides a collection of functions for various operations on tensors and ultrasound data. These functions can be used standalone, in contrast to the zea.ops module, which provides operation classes for building processing pipelines.
Functions
| Function | Description |
|---|---|
| L1 | L1 norm of a real tensor. |
| L2 | L2 norm of a real tensor. |
| add_salt_and_pepper_noise | Adds salt and pepper noise to the input image. |
| apply_along_axis | Apply a function to 1D array slices along an axis. |
| batch_cov | Compute the batch covariance matrices of the input tensor. |
| boolean_mask | Apply a boolean mask to a tensor. |
| check_patches_fit | Checks if patches with overlap fit an integer amount in the original image. |
| compute_required_patch_overlap | Compute required overlap between patches to cover the entire image. |
| compute_required_patch_shape | Compute required patch shape to cover the entire image. |
| correlate | Complex correlation via splitting real and imaginary parts. |
| extend_n_dims | Extend the number of dimensions of an array. |
| find_contour | Extract contour/boundary points from a binary mask using edge detection. |
| flatten | Flatten a range of dimensions. Should be similar to: https://pytorch.org/docs/stable/generated/torch.flatten.html |
| fori_loop | For loop with the same signature as jax.lax.fori_loop that can optionally run without JIT compilation. |
| func_with_one_batch_dim | Wraps a function to apply it to an input tensor with one or more batch dimensions. |
| gaussian_filter | Multidimensional Gaussian filter. |
| gaussian_filter1d | 1-D Gaussian filter. |
| images_to_patches | Creates patches from images. |
| interpolate_data | Interpolate subsampled data along a specified axis using map_coordinates. |
|  | Distinguish between jax.random.PRNGKey() and jax.random.key(). |
| is_monotonic | Checks if a given 1D array is monotonic. |
| linear_sum_assignment | Greedy linear sum assignment. |
| map_indices_for_interpolation | Interpolates a 1D array of indices with gaps. |
| matrix_power | Compute the power of a square matrix. |
| nonzero | Return the indices of the elements that are non-zero. |
| normalize | Normalize data to a given range. |
| pad_array_to_divisible | Pad an array to be divisible by N along the specified axis. |
| patches_to_images | Reconstructs images from patches. |
| resample | Resample tensor along axis. |
| reshape_axis | Reshape data along axis. |
|  | Like ops.map but without tracing or JIT compilation. |
|  | Sinc function. |
| split_seed | Split a seed into n seeds for reproducible random ops. |
| split_volume_data_from_axis | Splits previously stacked tensor data back to its original shape. |
| stack_volume_data_along_axis | Stacks tensor data along a specified stack axis. |
| translate | Map values in array from one range to other. |
| vmap | vmap with batching or chunking support to avoid memory issues. |
| channels_to_complex | Convert array with real and imaginary components at different channels to complex data array. |
| complex_to_channels | Unroll complex data to separate channels. |
| compute_time_to_peak | Compute the time of the peak of the waveform. |
| compute_time_to_peak_stack | Compute the time of the peak of each waveform in a stack of waveforms. |
| demodulate | Demodulates the input data to baseband. |
| demodulate_not_jitable | Demodulates an RF signal to complex base-band (IQ). |
| envelope_detect | Envelope detection of RF signals. |
| get_band_pass_filter | Band-pass filter. |
| get_low_pass_iq_filter | Design complex low-pass filter. |
| hilbert | Implementation of the Hilbert transform function that computes the analytical signal. |
| upmix | Upsamples and upmixes complex base-band signals (IQ) to RF. |
|  | Apply logarithmic compression to data. |
| make_tgc_curve | Create a Time Gain Compensation (TGC) curve to compensate for depth-dependent attenuation. |
- zea.func.L1(x)
L1 norm of a real tensor.
Implementation of L1 norm for real vectors: https://mathworld.wolfram.com/L1-Norm.html
- zea.func.L2(x)
L2 norm of a real tensor.
Implementation of L2 norm for real vectors: https://mathworld.wolfram.com/L2-Norm.html
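As a quick illustration of the two definitions linked above, here is a minimal NumPy sketch. It assumes the norm is taken over all elements of x; zea's own functions operate on backend tensors and may reduce differently.

```python
import numpy as np

def l1_norm(x):
    # Sum of absolute values (assumed: reduction over all elements).
    return np.sum(np.abs(x))

def l2_norm(x):
    # Square root of the sum of squares (assumed: reduction over all elements).
    return np.sqrt(np.sum(np.square(x)))

x = np.array([3.0, -4.0])
print(l1_norm(x), l2_norm(x))  # 7.0 5.0
```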
- zea.func.add_salt_and_pepper_noise(image, salt_prob, pepper_prob=None, seed=None)
Adds salt and pepper noise to the input image.
- Parameters:
image (ndarray) – The input image, must be of type float32 and normalized between 0 and 1.
salt_prob (float) – The probability of adding salt noise to each pixel.
pepper_prob (float, optional) – The probability of adding pepper noise to each pixel. If not provided, it will be set to the same value as salt_prob.
seed – A Python integer or instance of keras.random.SeedGenerator, used to make the noise deterministic. Note that seeding with an integer or None (unseeded) produces the same random values across multiple calls. To get different random values across multiple calls, pass an instance of keras.random.SeedGenerator as seed.
- Returns:
The noisy image with salt and pepper noise added.
- Return type:
ndarray
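The sketch below shows one way to add salt and pepper noise with NumPy. It is an assumption, not zea's implementation: salt pixels are set to 1.0 and pepper pixels to 0.0 (consistent with an image normalized to [0, 1]), and NumPy's default_rng stands in for keras.random seeding. A single uniform draw per pixel keeps salt and pepper mutually exclusive.

```python
import numpy as np

def salt_and_pepper(image, salt_prob, pepper_prob=None, seed=None):
    if pepper_prob is None:
        pepper_prob = salt_prob  # documented default: same as salt_prob
    rng = np.random.default_rng(seed)
    u = rng.random(image.shape)  # one uniform draw per pixel
    noisy = image.copy()
    # Assumed: salt -> 1.0, pepper -> 0.0 (image normalized to [0, 1]).
    noisy[u < salt_prob] = 1.0
    noisy[(u >= salt_prob) & (u < salt_prob + pepper_prob)] = 0.0
    return noisy

image = np.full((64, 64), 0.5, dtype=np.float32)
noisy = salt_and_pepper(image, 0.1, seed=0)
```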
- zea.func.apply_along_axis(func1d, axis, arr, *args, **kwargs)
Apply a function to 1D array slices along an axis.
Keras implementation of numpy.apply_along_axis. Copies the jax implementation, which uses vmap to vectorize the function application along the specified axis.
- Parameters:
func1d – A callable function with signature func1d(arr, /, *args, **kwargs), where *args and **kwargs are the additional positional and keyword arguments passed to apply_along_axis.
axis – Integer axis along which to apply the function.
arr – The array over which to apply the function.
*args – Additional positional arguments passed through to func1d.
**kwargs – Additional keyword arguments passed through to func1d.
- Returns:
The result of func1d applied along the specified axis.
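The idea behind apply_along_axis is to move the target axis to the end, treat every remaining index as a batch element, and map func1d over the resulting 1D slices (a vmap vectorizes that map). A minimal NumPy sketch with a plain Python loop in place of vmap, assuming func1d returns a scalar or a 1D array:

```python
import numpy as np

def apply_along_axis_sketch(func1d, axis, arr, *args, **kwargs):
    # Move the target axis last and flatten all other axes into one batch axis.
    moved = np.moveaxis(arr, axis, -1)
    lead_shape = moved.shape[:-1]
    rows = moved.reshape(-1, moved.shape[-1])
    # Map func1d over the 1D slices (vmap would vectorize this loop).
    out = np.stack([np.asarray(func1d(r, *args, **kwargs)) for r in rows])
    out = out.reshape(lead_shape + out.shape[1:])
    # A 1D result per slice goes back to the position of `axis`.
    if out.ndim > len(lead_shape):
        out = np.moveaxis(out, -1, axis)
    return out
```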
- zea.func.batch_cov(x, rowvar=True, bias=False, ddof=None)
Compute the batch covariance matrices of the input tensor.
- Parameters:
x (Tensor) – Input tensor of shape (…, m, n) where m is the number of features and n is the number of observations.
rowvar (bool, optional) – If True, each row represents a variable, while each column represents an observation. If False, each column represents a variable, while each row represents an observation. Defaults to True.
bias (bool, optional) – If True, the biased estimator of the covariance is computed. If False, the unbiased estimator is computed. Defaults to False.
ddof (int, optional) – Delta degrees of freedom. The divisor used in the calculation is (num_obs - ddof), where num_obs is the number of observations. If ddof is not specified, it is set to 0 if bias is True, and 1 if bias is False. Defaults to None.
- Returns:
Batch covariance matrices of shape (…, m, m) if rowvar=True, or (…, n, n) if rowvar=False.
- Return type:
Tensor
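A minimal NumPy sketch of the batched covariance described above, written to agree with np.cov on each batch element: center each variable, then divide the product of the centered data with its transpose by (num_obs - ddof). It illustrates the formula and is not zea's implementation.

```python
import numpy as np

def batch_cov(x, rowvar=True, bias=False, ddof=None):
    if not rowvar:
        # Columns are variables: swap so that rows are variables.
        x = np.swapaxes(x, -1, -2)
    if ddof is None:
        ddof = 0 if bias else 1
    n_obs = x.shape[-1]
    centered = x - x.mean(axis=-1, keepdims=True)
    # (..., m, n) @ (..., n, m) -> (..., m, m)
    return centered @ np.swapaxes(centered, -1, -2) / (n_obs - ddof)
```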
- zea.func.boolean_mask(tensor, mask, size=None)
Apply a boolean mask to a tensor.
- Parameters:
tensor (Tensor) – The input tensor.
mask (Tensor) – The boolean mask to apply.
size (int, optional) – The size of the output tensor. Only used with the JAX backend, where it is needed if you want to trace the function. Defaults to None.
- Returns:
The masked tensor.
- Return type:
Tensor
- zea.func.channels_to_complex(data)
Convert array with real and imaginary components at different channels to complex data array.
- Parameters:
data (ndarray) – Input data, with the real component at index 0 and the imaginary component at index 1 of the channel axis.
- Returns:
complex array with real and imaginary components.
- Return type:
ndarray
- zea.func.check_patches_fit(image_shape, patch_shape, overlap)
Checks if patches with overlap fit an integer amount in the original image.
- Parameters:
image_shape (tuple) – A tuple representing the shape of the original image.
patch_shape – A tuple representing the shape of the patches.
overlap (Union[int,Tuple[int,int]]) – An int or tuple of ints representing the overlap between patches.
- Return type:
tuple
- Returns:
A tuple containing a boolean indicating if the patches fit an integer amount in the original image and the new image shape if the patches do not fit.
Example
>>> from zea.func import check_patches_fit
>>> image_shape = (10, 10)
>>> patch_shape = (4, 4)
>>> overlap = (2, 2)
>>> patches_fit, new_shape = check_patches_fit(image_shape, patch_shape, overlap)
>>> patches_fit
True
>>> new_shape
(10, 10)
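The fit condition reduces to stride arithmetic along each axis, with stride = patch - overlap: patches tile an axis exactly when (image - patch) is a whole multiple of the stride. A minimal per-axis sketch, assuming this is the condition check_patches_fit applies; the helper names are hypothetical.

```python
def patches_fit_1d(image_len, patch_len, overlap):
    stride = patch_len - overlap
    # Exact tiling: the length left after the first patch is a whole number of strides.
    return (image_len - patch_len) % stride == 0

def n_patches_1d(image_len, patch_len, overlap):
    # First patch plus one patch per additional stride.
    return (image_len - patch_len) // (patch_len - overlap) + 1

print(patches_fit_1d(10, 4, 2), n_patches_1d(10, 4, 2))  # True 4
```

For the 8 x 8 images used in the images_to_patches example below, the same arithmetic gives (8 - 4) // 2 + 1 = 3 patches per axis.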
- zea.func.complex_to_channels(complex_data, axis=-1)
Unroll complex data to separate channels.
- Parameters:
complex_data (complex ndarray) – complex input data.
axis (int, optional) – on which axis to extend. Defaults to -1.
- Returns:
Real array with real and imaginary components unrolled over two channels at axis.
- Return type:
ndarray
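A NumPy sketch of the round trip between complex data and two real channels. It assumes the real and imaginary parts are stacked on a new trailing axis; zea may instead expand an existing size-1 channel axis, as in demodulate.

```python
import numpy as np

def complex_to_channels(z, axis=-1):
    # Real part -> channel 0, imaginary part -> channel 1 on a new axis.
    return np.stack([np.real(z), np.imag(z)], axis=axis)

def channels_to_complex(data):
    # Inverse mapping; assumes the two channels live on the last axis.
    return data[..., 0] + 1j * data[..., 1]
```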
- zea.func.compute_required_patch_overlap(image_shape, patch_shape)
Compute required overlap between patches to cover the entire image.
- Parameters:
image_shape – Tuple of (height, width)
patch_shape – Tuple of (patch_height, patch_width)
- Returns:
Tuple of (overlap_y, overlap_x), or None if no overlap results in an integer number of patches for the given image and patch shapes.
- zea.func.compute_required_patch_shape(image_shape, patch_shape, overlap)
Compute required patch shape to cover the entire image.
Compute the patch_shape closest to the original patch_shape that results in an integer number of patches given the image shape and overlap.
- Parameters:
image_shape – Tuple of (height, width)
patch_shape – Tuple of (patch_height, patch_width)
overlap – Tuple of (overlap_y, overlap_x)
- Returns:
Tuple of (patch_shape_y, patch_shape_x), or None if no patch_shape results in an integer number of patches given the image shape and overlap.
- zea.func.compute_time_to_peak(waveform, center_frequency, waveform_sampling_frequency=250000000.0)
Compute the time of the peak of the waveform.
- Parameters:
waveform (ndarray) – The waveform of shape (n_samples,).
center_frequency (float) – The center frequency of the waveform in Hz.
waveform_sampling_frequency (float) – The sampling frequency of the waveform in Hz.
- Returns:
The time to peak for the waveform in seconds.
- Return type:
float
- zea.func.compute_time_to_peak_stack(waveforms, center_frequencies, waveform_sampling_frequency=250000000.0)
Compute the time of the peak of each waveform in a stack of waveforms.
- Parameters:
waveforms (ndarray) – The waveforms of shape (n_waveforms, n_samples).
center_frequencies (ndarray) – The center frequencies of the waveforms in Hz of shape (n_waveforms,) or a scalar if all waveforms have the same center frequency.
waveform_sampling_frequency (float) – The sampling frequency of the waveforms in Hz.
- Returns:
The time to peak for each waveform in seconds.
- Return type:
ndarray
- zea.func.correlate(x, y, mode='full')
Complex correlation via splitting real and imaginary parts. Equivalent to np.correlate(x, y, mode).
NOTE: this function exists because TensorFlow does not support complex correlation. NOTE: TensorFlow also handles padding differently than NumPy, so the input is padded manually.
- Parameters:
x – np.ndarray (complex or real)
y – np.ndarray (complex or real)
mode – “full”, “valid”, or “same”
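The splitting works because np.correlate conjugates its second argument, so each term is (xr + i*xi)(yr - i*yi) = (xr*yr + xi*yi) + i*(xi*yr - xr*yi). A NumPy sketch of that identity using only real-valued correlations; the function name is hypothetical, and zea's version additionally handles TensorFlow's padding.

```python
import numpy as np

def complex_correlate(x, y, mode="full"):
    xr, xi = np.real(x), np.imag(x)
    yr, yi = np.real(y), np.imag(y)
    # Real part: xr*yr + xi*yi ; imaginary part: xi*yr - xr*yi
    real = np.correlate(xr, yr, mode) + np.correlate(xi, yi, mode)
    imag = np.correlate(xi, yr, mode) - np.correlate(xr, yi, mode)
    return real + 1j * imag
```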
- zea.func.demodulate(data, demodulation_frequency, sampling_frequency, axis=-3)
Demodulates the input data to baseband. The function computes the analytical signal (the signal with negative frequencies removed) and then shifts the spectrum of the signal to baseband by multiplying with a complex exponential. Where the spectrum was centered around demodulation_frequency before, it is now centered around 0 Hz. The baseband IQ data are complex-valued; the real and imaginary parts are stored in two real-valued channels.
- Parameters:
data (ops.Tensor) – The input data to demodulate of shape (…, axis, …, 1).
demodulation_frequency (float) – The center frequency of the signal.
sampling_frequency (float) – The sampling frequency of the signal.
axis (int, optional) – The axis along which to demodulate. Defaults to -3.
- Returns:
The demodulated IQ data of shape (…, axis, …, 2).
- Return type:
ops.Tensor
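The mixing step can be sketched in NumPy as follows. It assumes the analytic signal has already been computed (for example with hilbert) and that time starts at t = 0; the helper name mix_to_baseband is hypothetical. A tone at exactly the demodulation frequency ends up at 0 Hz, i.e. as a constant.

```python
import numpy as np

def mix_to_baseband(analytic, demodulation_frequency, sampling_frequency, axis=-1):
    n = analytic.shape[axis]
    t = np.arange(n) / sampling_frequency
    shape = [1] * analytic.ndim
    shape[axis] = n
    # Multiplying by exp(-2j*pi*f_d*t) shifts the spectrum down by f_d.
    iq = analytic * np.exp(-2j * np.pi * demodulation_frequency * t).reshape(shape)
    # Real and imaginary parts become two real-valued trailing channels.
    return np.stack([iq.real, iq.imag], axis=-1)

fs, f0, n = 40e6, 5e6, 256
t = np.arange(n) / fs
analytic = np.exp(2j * np.pi * f0 * t)  # analytic signal of cos(2*pi*f0*t)
iq = mix_to_baseband(analytic, f0, fs)
```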
- zea.func.demodulate_not_jitable(rf_data, sampling_frequency=None, demodulation_frequency=None, bandwidth=None, filter_coeff=None)
Demodulates an RF signal to complex base-band (IQ).
Demodulates the radiofrequency (RF) bandpass signals and returns the in-phase/quadrature (I/Q) components. IQ is a complex signal whose real (imaginary) part contains the in-phase (quadrature) component.
This function demodulates the RF signal along the fast-time axis, which is assumed to be the second-to-last axis (n_ax).
- Parameters:
rf_data (ndarray) – Real-valued input array of size […, n_ax, n_el]. The second-to-last axis is the fast-time axis.
sampling_frequency (float) – The sampling frequency of the RF signals (in Hz). Not required when filter_coeff is provided.
demodulation_frequency (float, optional) – Modulation frequency (in Hz).
bandwidth (float, optional) – Bandwidth of RF signal in % of center frequency. Defaults to None. The bandwidth in % is defined by: B = Bandwidth_in_% = Bandwidth_in_Hz*(100/center_frequency). The cutoff frequency is Wn = Bandwidth_in_Hz/sampling_frequency, i.e., Wn = B*(center_frequency/100)/sampling_frequency.
filter_coeff (list, optional) – (b, a), numerator and denominator coefficients of the FIR filter for quadrature band-pass filtering. If filter_coeff is provided, all other filter parameters are ignored and the given coefficients are used directly. If not provided, a filter is derived from the other parameters (sampling_frequency, demodulation_frequency, bandwidth). See https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.lfilter.html
- Returns:
complex valued base-band signal.
- Return type:
iq_data (ndarray)
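As a worked example of the bandwidth definition above, with illustrative values of a 5 MHz center frequency, a 20 MHz sampling frequency and a 60 % bandwidth:

```python
center_frequency = 5e6     # Hz (illustrative value)
sampling_frequency = 20e6  # Hz (illustrative value)
B = 60.0                   # bandwidth in % of the center frequency

bandwidth_hz = B * center_frequency / 100  # 3 MHz
Wn = bandwidth_hz / sampling_frequency     # 0.15
```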
- zea.func.envelope_detect(data, axis=-3)
Envelope detection of RF signals.
If the input data is real, it first applies the Hilbert transform along the specified axis and then computes the magnitude of the resulting complex signal. If the input data is complex, it computes the magnitude directly.
- Parameters:
data (Tensor) – The beamformed data of shape (…, grid_size_z, grid_size_x, n_ch).
axis (int) – Axis along which to apply the Hilbert transform. Defaults to -3.
- Returns:
The envelope detected data of shape (…, grid_size_z, grid_size_x).
- Return type:
envelope_data (Tensor)
- zea.func.extend_n_dims(arr, axis, n_dims)
Extend the number of dimensions of an array.
Inserts ‘n_dims’ ones at the specified axis.
- Parameters:
arr – The input array.
axis – The axis at which to insert the new dimensions.
n_dims – The number of dimensions to insert.
- Returns:
The array with the extended number of dimensions.
- Raises:
AssertionError – If the axis is out of range.
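Inserting n_dims singleton axes is a single reshape. A NumPy sketch, assuming a non-negative axis in the range 0..arr.ndim (zea's handling of negative axes is not documented here):

```python
import numpy as np

def extend_n_dims(arr, axis, n_dims):
    assert 0 <= axis <= arr.ndim, "axis out of range"
    # Insert n_dims singleton axes in front of position `axis`.
    new_shape = arr.shape[:axis] + (1,) * n_dims + arr.shape[axis:]
    return np.reshape(arr, new_shape)
```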
- zea.func.find_contour(binary_mask)
Extract contour/boundary points from a binary mask using edge detection.
This function finds the boundary pixels of objects in a binary mask by detecting pixels that have at least one neighbor with a different value (using 4-connectivity).
- Parameters:
binary_mask – Binary mask tensor of shape (H, W) with values 0 or 1.
- Returns:
Boundary points as tensor of shape (N, 2) in (row, col) format. Returns empty tensor of shape (0, 2) if no boundaries are found.
Example
>>> from zea.func import find_contour
>>> import keras
>>> mask = keras.ops.zeros((10, 10))
>>> mask = keras.ops.scatter_update(
...     mask, [[3, 3], [3, 4], [4, 3], [4, 4]], [1, 1, 1, 1]
... )
>>> contour = find_contour(mask)
>>> contour.shape
(4, 2)
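A NumPy sketch of 4-connectivity boundary detection. It keeps foreground pixels that have at least one background neighbor, which is consistent with the 4 points returned for the 2 x 2 block in the example above, and it assumes pixels outside the image count as background. Both choices are assumptions about zea's behavior.

```python
import numpy as np

def find_contour_sketch(binary_mask):
    mask = np.asarray(binary_mask).astype(bool)
    # Assumed: pixels outside the image are background (zero padding).
    padded = np.pad(mask, 1, constant_values=False)
    up = padded[:-2, 1:-1]
    down = padded[2:, 1:-1]
    left = padded[1:-1, :-2]
    right = padded[1:-1, 2:]
    # Boundary: a foreground pixel with at least one background 4-neighbor.
    boundary = mask & ~(up & down & left & right)
    return np.argwhere(boundary).reshape(-1, 2)  # (N, 2) in (row, col) order
```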
- zea.func.flatten(tensor, start_dim=0, end_dim=-1)
Flatten the dimensions from start_dim to end_dim into a single dimension. Should be similar to: https://pytorch.org/docs/stable/generated/torch.flatten.html
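Following the torch.flatten semantics referenced above (dimensions start_dim through end_dim, inclusive, are merged into one), a NumPy sketch; that zea matches these semantics exactly is an assumption.

```python
import numpy as np

def flatten(tensor, start_dim=0, end_dim=-1):
    start = start_dim % tensor.ndim
    end = end_dim % tensor.ndim
    shape = tensor.shape
    # Merge dimensions start..end (inclusive) into a single one.
    merged = int(np.prod(shape[start:end + 1]))
    return np.reshape(tensor, shape[:start] + (merged,) + shape[end + 1:])
```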
- zea.func.fori_loop(lower, upper, body_fun, init_val, disable_jit=False)
For loop with the same signature as jax.lax.fori_loop that can optionally run without JIT compilation.
- Parameters:
lower (int) – Lower bound of the loop.
upper (int) – Upper bound of the loop.
body_fun (function) – Function to be executed in the loop. As in jax.lax.fori_loop, it is called as body_fun(i, val) and returns the updated value.
init_val (any) – Initial value for the loop.
disable_jit (bool, optional) – If True, disables JIT compilation. Defaults to False.
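When JIT is disabled, the loop reduces to a plain Python for loop with the jax.lax.fori_loop contract, where body_fun(i, val) returns the next value and i runs over [lower, upper). A minimal sketch of that fallback path (not zea's code):

```python
def fori_loop_python(lower, upper, body_fun, init_val):
    # Same contract as jax.lax.fori_loop: val = body_fun(i, val) for i in [lower, upper).
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val

total = fori_loop_python(0, 5, lambda i, acc: acc + i, 0)  # 0+1+2+3+4 = 10
```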
- zea.func.func_with_one_batch_dim(func, tensor, n_batch_dims, batch_size=None, func_axis=None, **kwargs)
Wraps a function to apply it to an input tensor with one or more batch dimensions.
The function will be executed in parallel on all batch elements.
- Parameters:
func (function) – The function to apply to the image. If it returns multiple outputs, only the output selected by func_axis is kept.
tensor (Tensor) – The input tensor.
n_batch_dims (int) – The number of batch dimensions in the input tensor. Expects the input to start with n_batch_dims batch dimensions.
batch_size (int | None) – Integer specifying the size of the batch for each step to execute in parallel. Defaults to None, in which case the function will run everything in parallel.
func_axis (int | None) – If func returns multiple outputs, the output at this index will be returned.
**kwargs – Additional keyword arguments to pass to the function.
- Returns:
The output tensor with the same batch dimensions as the input tensor.
- Raises:
ValueError – If the number of batch dimensions is greater than the rank of the input tensor.
- zea.func.gaussian_filter(array, sigma, order=0, mode='symmetric', cval=None, truncate=4.0, axes=None)
Multidimensional Gaussian filter.
If you want to use this function with jax.jit, you can set: static_argnames=(“truncate”, “sigma”)
- Parameters:
array (Tensor) – The input array.
sigma (float or tuple) – Standard deviation for Gaussian kernel. The standard deviations of the Gaussian filter are given for each axis as a sequence, or as a single number, in which case it is equal for all axes.
order (Union[int,Tuple[int]]) – The order of the filter along each axis is given as a sequence of integers, or as a single number. An order of 0 corresponds to convolution with a Gaussian kernel. A positive order corresponds to convolution with that derivative of a Gaussian. Default is 0.
mode (str) – Padding mode for the input image. Default is ‘symmetric’. See [keras docs](https://www.tensorflow.org/api_docs/python/tf/keras/ops/pad) for all options and [tensorflow docs](https://www.tensorflow.org/api_docs/python/tf/pad) for some examples. Note that the naming differs from scipy.ndimage.gaussian_filter!
cval (float | None) – Value to fill past edges of input if mode is ‘constant’. Default is None.
truncate (float) – Truncate the filter at this many standard deviations. Default is 4.0.
axes (Tuple[int]) – If None, input is filtered along all axes. Otherwise, input is filtered along the specified axes. When axes is specified, any tuples used for sigma, order and/or mode must match the length of axes. The ith entry in any of these tuples corresponds to the ith entry in axes.
- zea.func.gaussian_filter1d(array, sigma, axis=-1, order=0, mode='symmetric', truncate=4.0, cval=None)
1-D Gaussian filter.
- Parameters:
array (Tensor) – The input array.
sigma (float or tuple) – Standard deviation for Gaussian kernel. The standard deviations of the Gaussian filter are given for each axis as a sequence, or as a single number, in which case it is equal for all axes.
order (int or Tuple[int]) – The order of the filter along each axis is given as a sequence of integers, or as a single number. An order of 0 corresponds to convolution with a Gaussian kernel. A positive order corresponds to convolution with that derivative of a Gaussian. Default is 0.
mode (str, optional) – Padding mode for the input image. Default is ‘symmetric’. See [keras docs](https://www.tensorflow.org/api_docs/python/tf/keras/ops/pad) for all options and [tensorflow docs](https://www.tensorflow.org/api_docs/python/tf/pad) for some examples. Note that the naming differs from scipy.ndimage.gaussian_filter!
cval (float, optional) – Value to fill past edges of input if mode is ‘constant’. Default is None.
truncate (float, optional) – Truncate the filter at this many standard deviations. Default is 4.0.
axis (int, optional) – The axis of array along which to filter. Defaults to -1.
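To illustrate what sigma and truncate control, the sketch below builds a normalized 1-D Gaussian kernel whose half-width is int(truncate * sigma + 0.5), the convention used by scipy.ndimage, and filters with NumPy's 'symmetric' padding. That zea uses the same radius rule is an assumption, and the sketch only covers order=0.

```python
import numpy as np

def gaussian_kernel1d(sigma, truncate=4.0):
    # Half-width: truncate standard deviations, rounded (scipy.ndimage convention).
    radius = int(truncate * sigma + 0.5)
    x = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (x / sigma) ** 2)
    return kernel / kernel.sum()  # unit sum preserves the signal mean

def gaussian_filter1d_sketch(array, sigma, truncate=4.0):
    kernel = gaussian_kernel1d(sigma, truncate)
    radius = len(kernel) // 2
    # 'symmetric' padding mirrors the signal including the edge sample.
    padded = np.pad(array, radius, mode="symmetric")
    return np.convolve(padded, kernel, mode="valid")
```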
- zea.func.get_band_pass_filter(num_taps, sampling_frequency, f1, f2, validate=True)
Band-pass filter.
Compatible with jax.jit when num_taps is static. Based on scipy.signal.firwin with a Hamming window.
- Parameters:
num_taps (int) – number of taps in filter.
sampling_frequency (float) – sample frequency in Hz.
f1 (float) – cutoff frequency in Hz of left band edge.
f2 (float) – cutoff frequency in Hz of right band edge.
validate (bool, optional) – whether to validate the cutoff frequencies. Defaults to True.
- Returns:
band pass filter
- Return type:
ndarray
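Since the filter is based on scipy.signal.firwin with a Hamming window, a windowed-sinc band-pass can be sketched in NumPy as the difference of two ideal low-pass responses, scaled to unit gain at the band center as firwin does by default. This is an illustration under those assumptions, not zea's implementation, and it omits the cutoff validation.

```python
import numpy as np

def band_pass_fir(num_taps, sampling_frequency, f1, f2):
    nyq = sampling_frequency / 2
    # Tap positions centered on zero give a symmetric, linear-phase filter.
    m = np.arange(num_taps) - (num_taps - 1) / 2
    # Ideal band-pass = ideal low-pass at f2 minus ideal low-pass at f1.
    h = (f2 / nyq) * np.sinc(f2 / nyq * m) - (f1 / nyq) * np.sinc(f1 / nyq * m)
    h = h * np.hamming(num_taps)
    # Scale to unit gain at the band center (firwin's default scaling).
    f_center = (f1 + f2) / 2
    return h / np.sum(h * np.cos(np.pi * m * f_center / nyq))

h = band_pass_fir(101, 20e6, 2e6, 6e6)
```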
- zea.func.get_low_pass_iq_filter(num_taps, sampling_frequency, center_frequency, bandwidth)
Design complex low-pass filter.
The filter is a low-pass FIR filter modulated to the center frequency.
- Parameters:
num_taps (int) – number of taps in filter.
sampling_frequency (float) – sample frequency in Hz.
center_frequency (float) – center frequency in Hz.
bandwidth (float) – bandwidth in Hz.
- Raises:
ValueError – if cutoff frequency (bandwidth / 2) is not within (0, sampling_frequency / 2)
- Returns:
Complex-valued low-pass filter
- Return type:
ndarray
- zea.func.hilbert(x, N=None, axis=-1)
Implementation of the Hilbert transform function that computes the analytical signal.
Operates in the Fourier domain by applying a filter that zeros out negative frequencies and doubles positive frequencies.
Note
This is NOT the mathematical Hilbert transform as defined in the Wikipedia article, but instead computes the analytical signal. The implementation reproduces the behavior of the scipy.signal.hilbert() function.
- Parameters:
x (ndarray) – Input data of any shape.
N (int) – Number of points to use for the FFT. If specified and greater than the length of the data along the specified axis, the data will be zero-padded. If None, uses the length of x along the specified axis. Defaults to None.
axis (int, optional) – Axis along which to compute the Hilbert transform. Defaults to -1 (last axis).
- Returns:
Complex analytical signal with the same shape as the input (or padded to length N if specified). The real part is the original signal and the imaginary part is the Hilbert transform of the signal.
- Return type:
ndarray
- Raises:
ValueError – If N is specified and is less than the length of x along the specified axis.
Example
>>> import numpy as np
>>> from zea.func import hilbert
>>> x = np.array([1.0, 2.0, 3.0, 4.0])
>>> analytical_signal = hilbert(x)
>>> envelope = np.abs(analytical_signal)
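The Fourier-domain filter described above (keep DC and, for even lengths, the Nyquist bin; double positive frequencies; zero negative frequencies) can be sketched in NumPy as follows. The function name analytic_signal is hypothetical; the filter construction follows the scipy.signal.hilbert approach that the note says is reproduced.

```python
import numpy as np

def analytic_signal(x, N=None, axis=-1):
    x = np.asarray(x)
    n_in = x.shape[axis]
    if N is None:
        N = n_in
    if N < n_in:
        raise ValueError("N must not be smaller than the length of x along axis")
    spectrum = np.fft.fft(x, n=N, axis=axis)  # zero-pads to N if needed
    h = np.zeros(N)
    h[0] = 1.0                    # keep DC
    if N % 2 == 0:
        h[N // 2] = 1.0           # keep Nyquist
        h[1:N // 2] = 2.0         # double positive frequencies
    else:
        h[1:(N + 1) // 2] = 2.0
    # Negative frequencies are multiplied by 0.
    shape = [1] * x.ndim
    shape[axis] = N
    return np.fft.ifft(spectrum * h.reshape(shape), axis=axis)
```

For x = [1.0, 2.0, 3.0, 4.0] the sketch returns [1+1j, 2-1j, 3-1j, 4+1j]; the real part reproduces x, as stated above.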
- zea.func.images_to_patches(images, patch_shape, overlap=None)
Creates patches from images.
- Parameters:
images (KerasTensor) – Input images [batch, height, width, channels].
patch_shape (Union[int,Tuple[int,int]]) – Height and width of patch.
overlap (Union[int,Tuple[int,int]]) – Overlap between patches in px. Defaults to None.
- Returns:
Batch of patches of size [batch, #patch_y, #patch_x, patch_size_y, patch_size_x, #channels].
- Return type:
KerasTensor
Example
>>> import keras
>>> from zea.func import images_to_patches
>>> images = keras.random.uniform((2, 8, 8, 3))
>>> patches = images_to_patches(images, patch_shape=(4, 4), overlap=(2, 2))
>>> patches.shape
(2, 3, 3, 4, 4, 3)
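A loop-based NumPy sketch of patch extraction with stride = patch - overlap. It reproduces the shape in the example above but is an illustration, not zea's implementation; it assumes channels-last input and tuple-valued patch_shape and overlap.

```python
import numpy as np

def images_to_patches_sketch(images, patch_shape, overlap=(0, 0)):
    b, h, w, c = images.shape
    ph, pw = patch_shape
    sy, sx = ph - overlap[0], pw - overlap[1]  # stride = patch - overlap
    ny = (h - ph) // sy + 1
    nx = (w - pw) // sx + 1
    patches = np.empty((b, ny, nx, ph, pw, c), dtype=images.dtype)
    for i in range(ny):
        for j in range(nx):
            patches[:, i, j] = images[:, i * sy:i * sy + ph, j * sx:j * sx + pw, :]
    return patches

images = np.random.default_rng(0).random((2, 8, 8, 3))
patches = images_to_patches_sketch(images, (4, 4), overlap=(2, 2))
print(patches.shape)  # (2, 3, 3, 4, 4, 3)
```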
- zea.func.interpolate_data(subsampled_data, mask, order=1, axis=-1, fill_mode='nearest', fill_value=0)
Interpolate subsampled data along a specified axis using map_coordinates.
- Parameters:
subsampled_data (ndarray) – The data subsampled along the specified axis. Its shape matches mask except along the subsampled axis.
mask (ndarray) – Boolean array with the same shape as the full data. True where data is known.
order (int, optional) – The order of the spline interpolation. Default is 1.
axis (int, optional) – The axis along which the data is subsampled. Default is -1.
fill_mode (str, optional) – Points outside the boundaries of the input are filled according to the given mode. Default is ‘nearest’. For more info see keras.ops.image.map_coordinates.
fill_value (float, optional) – Value to use for points outside the boundaries of the input if fill_mode is ‘constant’. Default is 0. For more info see keras.ops.image.map_coordinates.
- Returns:
The data interpolated back to the original grid.
- Return type:
ndarray
- Raises:
ValueError – If mask does not indicate any missing data or if mask has False values along multiple axes.
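- zea.func.is_monotonic(array, increasing=True)
Checks if a given 1D array is monotonic.
Either entirely non-decreasing or non-increasing.
- Parameters:
array (ndarray) – A 1D numpy array.
increasing (bool, optional) – If True, checks whether the array is non-decreasing; if False, whether it is non-increasing. Defaults to True.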
- Returns:
True if the array is monotonic, False otherwise.
- Return type:
bool
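A NumPy sketch of the check, comparing consecutive differences non-strictly so that repeated values are allowed. How zea maps the increasing flag is an assumption based on the description above.

```python
import numpy as np

def is_monotonic(array, increasing=True):
    diffs = np.diff(array)
    # Non-strict comparison: equal neighbours do not break monotonicity.
    if increasing:
        return bool(np.all(diffs >= 0))
    return bool(np.all(diffs <= 0))
```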
- zea.func.linear_sum_assignment(cost)
Greedy linear sum assignment.
- Parameters:
cost (Tensor) – Cost matrix of shape (n, n).
- Returns:
Row indices and column indices for assignment.
- Return type:
Tuple
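The exact greedy strategy is not documented here; one common variant repeatedly picks the cheapest remaining (row, column) pair and removes that row and column. A NumPy sketch of that variant, which is an assumption about zea's behavior:

```python
import numpy as np

def greedy_assignment(cost):
    cost = np.array(cost, dtype=float)
    rows, cols = [], []
    for _ in range(cost.shape[0]):
        # Cheapest remaining (row, col) pair.
        r, c = np.unravel_index(np.argmin(cost), cost.shape)
        rows.append(int(r))
        cols.append(int(c))
        # Exclude the chosen row and column from later picks.
        cost[r, :] = np.inf
        cost[:, c] = np.inf
    order = np.argsort(rows)
    return np.array(rows)[order], np.array(cols)[order]

rows, cols = greedy_assignment([[4, 1, 3], [2, 0, 5], [3, 2, 2]])
```

For this cost matrix the sketch returns the diagonal assignment with total cost 4 + 0 + 2 = 6, while the optimal assignment (row 0 to column 1, row 1 to column 0, row 2 to column 2) costs 5, which illustrates why a greedy assignment is not guaranteed to be optimal.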
- zea.func.make_tgc_curve(n_ax, attenuation_coef, sampling_frequency, center_frequency, sound_speed=1540)
Create a Time Gain Compensation (TGC) curve to compensate for depth-dependent attenuation.
- Parameters:
n_ax (int) – Number of samples in the axial direction
attenuation_coef (float) – Attenuation coefficient in dB/cm/MHz. For example, a typical value for soft tissue is around 0.5 to 0.75 dB/cm/MHz.
sampling_frequency (float) – Sampling frequency in Hz
center_frequency (float) – Center frequency in Hz
sound_speed (float) – Speed of sound in m/s (default: 1540)
- Returns:
TGC gain curve of shape (n_ax,) in linear scale
- Return type:
np.ndarray
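A sketch of how such a curve can be built from the documented parameters. It rests on assumptions not stated above: depth is computed from round-trip travel time (depth = sound_speed * t / 2), attenuation accumulates over the round trip, and the gain is an amplitude gain (10 ** (dB / 20)).

```python
import numpy as np

def tgc_curve(n_ax, attenuation_coef, sampling_frequency, center_frequency,
              sound_speed=1540.0):
    t = np.arange(n_ax) / sampling_frequency  # fast time of each sample (s)
    depth_cm = sound_speed * t / 2 * 100      # assumed: round-trip time -> depth (cm)
    f_mhz = center_frequency / 1e6
    # Assumed: attenuation accumulates over the round trip (2 * depth).
    loss_db = attenuation_coef * f_mhz * 2 * depth_cm
    return 10 ** (loss_db / 20)               # assumed amplitude gain, linear scale
```

With 0.5 dB/cm/MHz at 5 MHz, a reflector at 1 cm depth loses 5 dB over the round trip under these assumptions, so the curve applies a gain of 10 ** 0.25 (about 1.78) at that depth.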
- zea.func.map_indices_for_interpolation(indices)
Interpolates a 1D array of indices with gaps.
Maps a 1D array of indices with gaps to a 1D array where gaps would be between the integers.
Used in the interpolate_data function.
- Parameters:
indices – A 1D array of indices with gaps.
- Returns:
The indices mapped to a 1D array where gaps fall between the integers.
- Return type:
(indices)
There are two segments here of length 4 and 2.
Example
>>> indices = [5, 6, 7, 8, 12, 13, 19]
>>> mapped_indices = [5, 5.25, 5.5, 5.75, 8, 8.5, 12.5]
- zea.func.matrix_power(matrix, power)
Compute the power of a square matrix.
Should match the [numpy](https://numpy.org/doc/stable/reference/generated/numpy.linalg.matrix_power.html) implementation.
- Parameters:
matrix (array-like) – A square matrix to be raised to a power.
power (int) – The exponent to which the matrix is to be raised. Must be a non-negative integer.
- Returns:
The resulting matrix after raising the input matrix to the specified power.
- Return type:
array-like
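A standard way to compute a non-negative integer matrix power is repeated squaring, which needs O(log power) matrix products. A NumPy sketch (not necessarily how zea computes it) that matches numpy.linalg.matrix_power for non-negative powers:

```python
import numpy as np

def matrix_power(matrix, power):
    if power < 0:
        raise ValueError("power must be a non-negative integer")
    result = np.eye(matrix.shape[0], dtype=matrix.dtype)
    base = matrix
    while power > 0:
        if power & 1:         # lowest bit set: fold the current square into the result
            result = result @ base
        base = base @ base    # square for the next bit
        power >>= 1
    return result
```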
- zea.func.nonzero(x, size=None, fill_value=None)
Return the indices of the elements that are non-zero.
- zea.func.normalize(data, output_range, input_range=None)
Normalize data to a given range.
Equivalent to translate with clipping.
- Parameters:
data (ops.Tensor) – Input data to normalize.
output_range (tuple) – Range to which data should be mapped, e.g., (0, 1).
input_range (tuple, optional) – Range of input data. If None, the range will be computed from the data. Defaults to None.
- zea.func.pad_array_to_divisible(arr, N, axis=0, mode='constant', pad_value=None)
Pad an array to be divisible by N along the specified axis.
- Parameters:
arr (Tensor) – The input array to pad.
N (int) – The number to which the length of the specified axis should be divisible.
axis (int, optional) – The axis along which to pad the array. Defaults to 0.
mode (str, optional) – The padding mode to use. One of “constant”, “edge”, “linear_ramp”, “maximum”, “mean”, “median”, “minimum”, “reflect”, “symmetric”, “wrap”, “empty”, “circular”. Defaults to “constant”.
pad_value (float, optional) – The value to use for padding when mode=’constant’. Defaults to None. If mode is not constant, this value should be None.
- Returns:
The padded array.
- Return type:
Tensor
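The amount of padding is (-length) % N, which is zero when the length is already divisible by N. A NumPy sketch, assuming padding is appended at the end of the axis and that pad_value=None means zero in constant mode:

```python
import numpy as np

def pad_to_divisible(arr, N, axis=0, mode="constant", pad_value=None):
    pad = (-arr.shape[axis]) % N   # 0 if already divisible
    pad_width = [(0, 0)] * arr.ndim
    pad_width[axis] = (0, pad)     # assumed: pad at the end of the axis
    if mode == "constant":
        value = 0 if pad_value is None else pad_value
        return np.pad(arr, pad_width, mode=mode, constant_values=value)
    return np.pad(arr, pad_width, mode=mode)
```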
- zea.func.patches_to_images(patches, image_shape, overlap=None, window_type='average')
Reconstructs images from patches.
- Parameters:
patches (KerasTensor) – Array with batch of patches to convert to batch of images. [batch_size, #patch_y, #patch_x, patch_size_y, patch_size_x, n_channels]
image_shape (tuple) – Shape of output image. (height, width, channels)
overlap (Union[int,Tuple[int,int]]) – Overlap between patches in px. Defaults to None.
window_type (str, optional) – Type of stitching to use. Defaults to ‘average’. Options: ‘average’, ‘replace’.
- Returns:
Reconstructed batch of images from batch of patches.
- Return type:
KerasTensor
Example
>>> import keras
>>> from zea.func import patches_to_images
>>> patches = keras.random.uniform((2, 3, 3, 4, 4, 3))
>>> images = patches_to_images(patches, image_shape=(8, 8, 3), overlap=(2, 2))
>>> images.shape
(2, 8, 8, 3)
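With window_type='average', each output pixel can be computed as the mean of all patches that cover it. A loop-based NumPy sketch of that overlap-add averaging; the helper name is hypothetical, and the stride = patch - overlap layout is assumed to match images_to_patches.

```python
import numpy as np

def patches_to_images_average(patches, image_shape, overlap=(0, 0)):
    b, ny, nx, ph, pw, c = patches.shape
    sy, sx = ph - overlap[0], pw - overlap[1]  # stride between patch origins
    h, w, _ = image_shape
    total = np.zeros((b, h, w, c))
    count = np.zeros((1, h, w, 1))
    for i in range(ny):
        for j in range(nx):
            rows = slice(i * sy, i * sy + ph)
            cols = slice(j * sx, j * sx + pw)
            total[:, rows, cols, :] += patches[:, i, j]
            count[:, rows, cols, :] += 1
    # Each pixel becomes the mean of all patches covering it.
    return total / count
```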
- zea.func.resample(x, n_samples, axis=-2, order=1)
Resample tensor along axis.
Similar to scipy.signal.resample.
- Parameters:
x – input tensor.
n_samples – number of samples after resampling.
axis – axis to resample along.
order – interpolation order (1=linear).
- Returns:
Resampled tensor.
- zea.func.reshape_axis(data, newshape, axis)
Reshape data along axis.
- Parameters:
data (tensor) – input data.
newshape (tuple) – new shape of data along axis.
axis (int) – axis to reshape.
Example
>>> import keras
>>> from zea.func import reshape_axis
>>> data = keras.random.uniform((3, 4, 5))
>>> newshape = (2, 2)
>>> reshaped_data = reshape_axis(data, newshape, axis=1)
>>> reshaped_data.shape
(3, 2, 2, 5)
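reshape_axis amounts to replacing one axis by the dimensions in newshape while leaving the other axes untouched. A NumPy sketch that reproduces the example above:

```python
import numpy as np

def reshape_axis(data, newshape, axis):
    axis = axis % data.ndim
    # Replace the size of `axis` by the dimensions in `newshape`.
    return np.reshape(data, data.shape[:axis] + tuple(newshape) + data.shape[axis + 1:])
```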
- zea.func.split_seed(seed, n)
Split a seed into n seeds for reproducible random ops.
Supports keras.random.SeedGenerator and JAX random keys.
- Parameters:
seed – None, jax.Array, or keras.random.SeedGenerator.
n (int) – Number of seeds to generate.
- Returns:
List of n seeds (JAX keys, SeedGenerator, or None).
- Return type:
list
- zea.func.split_volume_data_from_axis(data, batch_axis, stack_axis, number, padding)
Splits previously stacked tensor data back to its original shape.
This function reverses the operation performed by stack_volume_data_along_axis.
- Parameters:
data (Tensor) – Input tensor to be split.
batch_axis (int) – Axis along which to restore the blocks.
stack_axis (int) – Axis from which to split the stacked data.
number (int) – Number of slices per stack.
padding (int) – Amount of padding to remove from the result.
- Returns:
Reshaped tensor with data split back to original format.
- Return type:
Tensor
Example
>>> import keras
>>> from zea.func import split_volume_data_from_axis
>>> data = keras.random.uniform((20, 10, 30))
>>> split_data = split_volume_data_from_axis(data, 0, 1, 2, 2)
>>> split_data.shape
(39, 5, 30)
- zea.func.stack_volume_data_along_axis(data, batch_axis, stack_axis, number)
Stacks tensor data along a specified stack axis.
Stack tensor data along a specified stack axis by splitting it into blocks along the batch axis.
- Parameters:
data (Tensor) – Input tensor to be stacked.
batch_axis (int) – Axis along which to split the data into blocks.
stack_axis (int) – Axis along which to stack the blocks.
number (int) – Number of slices per stack.
- Returns:
Reshaped tensor with data stacked along stack_axis.
- Return type:
Tensor
Example
>>> import keras
>>> from zea.func import stack_volume_data_along_axis
>>> data = keras.random.uniform((10, 20, 30))
>>> # stacking along 1st axis with 2 frames per block
>>> stacked_data = stack_volume_data_along_axis(data, 0, 1, 2)
>>> stacked_data.shape
(5, 40, 30)
- zea.func.translate(array, range_from=None, range_to=(0, 255))
Map values in array from one range to other.
- Parameters:
array (ndarray) – input array.
range_from (Tuple, optional) – lower and upper bound of original array. Defaults to min and max of array.
range_to (Tuple, optional) – lower and upper bound to which array should be mapped. Defaults to (0, 255).
- Returns:
translated array
- Return type:
(ndarray)
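Both translate and normalize are the linear map that sends range_from to range_to; per its description, normalize additionally clips to the input range. A NumPy sketch under that reading of the docstrings (clipping the input before mapping is an assumption about where zea clips):

```python
import numpy as np

def translate(array, range_from=None, range_to=(0, 255)):
    if range_from is None:
        range_from = (np.min(array), np.max(array))  # documented default
    a, b = range_from
    c, d = range_to
    # Linear map sending a -> c and b -> d.
    return (array - a) / (b - a) * (d - c) + c

def normalize(data, output_range, input_range=None):
    # "translate with clipping": clip to the input range, then map.
    if input_range is not None:
        data = np.clip(data, *input_range)
    return translate(data, input_range, output_range)
```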
- zea.func.upmix(iq_data, sampling_frequency, demodulation_frequency, upsampling_rate=6)
Upsamples and upmixes complex base-band signals (IQ) to RF.
- Parameters:
iq_data (ndarray) – Complex-valued input array of size […, n_ax, n_el]. The second-to-last axis is the fast-time axis.
sampling_frequency (float) – The sampling frequency of the input IQ signal (in Hz). The resulting sampling frequency of the RF data is upsampling_rate times higher.
demodulation_frequency (float) – Modulation frequency (in Hz).
upsampling_rate (int, optional) – Upsampling factor applied to the IQ data. Defaults to 6.
- Returns:
output real valued rf data.
- Return type:
rf_data (ndarray)
- zea.func.vmap(fun, in_axes=0, out_axes=0, chunks=None, batch_size=None, fn_supports_batch=False, disable_jit=False, _use_torch_vmap=False)
vmap with batching or chunking support to avoid memory issues.
A wrapper around vmap that splits the input into batches or chunks to avoid memory issues with large inputs. Choose batch_size or chunks carefully: the input is padded to make it divisible, and the output is then cropped back to the original size.
- Parameters:
fun (callable) – Function to be mapped.
in_axes (Union[List[Optional[int]],int]) – Axis or axes to be mapped over in the input.
out_axes (Union[List[Optional[int]],int]) – Axis or axes to be mapped over in the output.
batch_size (int | None) – Size of the batch for each step. If None, the function will be equivalent to vmap. If 1, will be equivalent to map. Mutually exclusive with chunks.
chunks (int | None) – Number of chunks to split the input into. If None or 1, the function will be equivalent to vmap. Mutually exclusive with batch_size.
fn_supports_batch (bool) – If True, assumes that fun can already handle batched inputs. In this case, this function will only handle padding and reshaping for batching.
disable_jit (bool) – If True, disables JIT compilation for backends that support it. This can be useful for debugging. Will fall back to simple mapping.
_use_torch_vmap (bool) – If True, uses PyTorch’s native vmap implementation. Advantage: you can apply vmap multiple times without issues. Disadvantage: does not support None arguments.
- Returns:
A function that applies fun in a batched manner over the specified axes.
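The padding-and-cropping strategy described for vmap can be sketched for the batched case (fn_supports_batch=True) along axis 0. This NumPy sketch pads by repeating the last element, which is an assumption; zea may pad differently. The helper name batched_map is hypothetical.

```python
import numpy as np

def batched_map(fn_batched, x, batch_size):
    """Apply fn_batched to x in fixed-size batches along axis 0."""
    n = x.shape[0]
    pad = (-n) % batch_size
    # Pad axis 0 (here by repeating the last element) so it splits evenly.
    if pad:
        x = np.concatenate([x, np.repeat(x[-1:], pad, axis=0)], axis=0)
    batches = x.reshape(-1, batch_size, *x.shape[1:])
    out = np.concatenate([fn_batched(b) for b in batches], axis=0)
    # Crop the results that belong to padded elements.
    return out[:n]
```

The padding is why the batch size matters: with 10 elements and batch_size=4, three batches of 4 are evaluated and 2 padded results are thrown away.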