zea.backend¶
Backend-specific utilities.
This subpackage provides backend-specific utilities for the zea library. Most backend logic is handled by Keras 3, but a few features require custom wrappers to ensure compatibility and performance across JAX, TensorFlow, and PyTorch.
Note
Most backend-specific logic is handled by Keras 3, so this subpackage is intentionally minimal. Only features not natively supported by Keras (such as JIT and autograd) are implemented here.
Key Features¶
- JIT Compilation (zea.backend.jit()): Provides a unified interface for just-in-time (JIT) compilation of functions, dispatching to the appropriate backend (JAX or TensorFlow) as needed. This enables accelerated execution of computationally intensive routines. Note that JIT compilation is not yet supported when using the torch backend.
- Automatic Differentiation (zea.backend.AutoGrad): Offers a backend-agnostic wrapper for automatic differentiation, allowing gradient computation regardless of the underlying ML library.
- Backend Submodules:
  - zea.backend.jax – JAX-specific utilities and device management.
  - zea.backend.torch – PyTorch-specific utilities and device management.
  - zea.backend.tensorflow – TensorFlow-specific utilities and device management, as well as data loading utilities.
- Data Loading (zea.backend.tensorflow.make_dataloader()): Implemented using TensorFlow's efficient data pipeline utilities. It provides a convenient way to load and preprocess data for machine learning workflows, leveraging TensorFlow's tf.data.Dataset API.
Functions
- func_on_device – Moves all tensor arguments of a function to a specified device before calling it.
- jit – Applies JIT compilation to the given function based on the current Keras backend.
- tf_function – Applies default tf.function to the given function.
Classes
- on_device – Context manager to set the device regardless of backend.
- zea.backend.func_on_device(func, device, *args, **kwargs)[source]¶
Moves all tensor arguments of a function to a specified device before calling it.
- Parameters:
func (callable) – Function to be called.
device (str) – Device to move tensors to.
*args – Positional arguments to be passed to the function.
**kwargs – Keyword arguments to be passed to the function.
- Returns:
The output of the function.
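The semantics of func_on_device can be sketched with a minimal stand-in. The Tensor class and its .to(device) method below are hypothetical placeholders for illustration only; the real implementation moves actual backend tensors.

```python
# Illustrative sketch of func_on_device, assuming a toy Tensor
# stand-in with a .to(device) method (hypothetical, for demonstration).
class Tensor:
    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device

    def to(self, device):
        # Return a copy placed on the target device.
        return Tensor(self.data, device)


def func_on_device(func, device, *args, **kwargs):
    # Move every Tensor argument to the target device, then call func.
    move = lambda x: x.to(device) if isinstance(x, Tensor) else x
    args = [move(a) for a in args]
    kwargs = {k: move(v) for k, v in kwargs.items()}
    return func(*args, **kwargs)


def add(a, b):
    # Both arguments arrive on the requested device.
    return Tensor(a.data + b.data, a.device)


out = func_on_device(add, "gpu:0", Tensor(1), Tensor(2))
print(out.data, out.device)  # → 3 gpu:0
```

The key design point is that the wrapper only touches tensor-like arguments and passes everything else through untouched.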
- zea.backend.jit(func=None, jax=True, tensorflow=True, **kwargs)[source]¶
Applies JIT compilation to the given function based on the current Keras backend. Can be used as a decorator or as a function.
- Parameters:
func (callable) – The function to be JIT compiled.
jax (bool) – Whether to enable JIT compilation in the JAX backend.
tensorflow (bool) – Whether to enable JIT compilation in the TensorFlow backend.
**kwargs – Keyword arguments to be passed to the JIT compiler.
- Returns:
The JIT-compiled function.
- Return type:
callable
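The dispatch logic can be sketched roughly as follows. This is a simplified illustration, not the actual zea source; the backend() helper here is a stand-in for querying keras.backend.backend(), and the stand-in returns "numpy" so the example runs without JAX or TensorFlow installed.

```python
# Sketch of a backend-dispatched JIT decorator, usable both as
# @jit and as jit(func, ...). Assumes a backend() helper that
# reports the active Keras backend (stubbed here for illustration).
def backend():
    return "numpy"  # stand-in; real code would query keras.backend.backend()


def jit(func=None, jax=True, tensorflow=True, **kwargs):
    def decorate(f):
        b = backend()
        if b == "jax" and jax:
            import jax as _jax
            return _jax.jit(f, **kwargs)
        if b == "tensorflow" and tensorflow:
            import tensorflow as tf
            return tf.function(f, jit_compile=True, **kwargs)
        return f  # torch (and others): JIT not supported, return unchanged

    # Support both decorator-with-arguments and plain-call usage.
    if func is None:
        return decorate
    return decorate(func)


@jit
def add(a, b):
    return a + b


print(add(2, 3))  # → 5
```

The `func=None` default is what allows the same callable to serve as both `@jit` and `@jit(jax=False)`.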
- class zea.backend.on_device(device)[source]¶
Bases: object
Context manager to set the device regardless of backend.
For the torch backend, you need to manually move the model and data to the device before using this context manager.
- Parameters:
device (str) – Device string, e.g. 'cuda', 'gpu', or 'cpu'.
Example
with zea.backend.on_device("gpu:3"):
    pipeline = zea.Pipeline([zea.ops.Abs()])
    output = pipeline(data=keras.random.normal((10, 10)))
    # output is on "cuda:3"
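Conceptually, such a context manager saves the current device on entry and restores it on exit. The sketch below is an illustrative stand-in that tracks a device string on the class itself; it is not the actual zea implementation, which delegates to each backend's device-placement API.

```python
# Minimal sketch of a device context manager (illustrative only):
# entering sets a "current device", exiting restores the previous one.
class on_device:
    _current = "cpu"  # stand-in for the backend's active device

    def __init__(self, device):
        self.device = device

    def __enter__(self):
        self._prev = on_device._current
        on_device._current = self.device
        return self

    def __exit__(self, *exc):
        on_device._current = self._prev
        return False  # do not suppress exceptions


with on_device("gpu:3"):
    print(on_device._current)  # → gpu:3
print(on_device._current)  # → cpu
```

Restoring the previous device in __exit__ (rather than hard-coding a default) is what makes nested on_device blocks behave correctly.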
- zea.backend.tf_function(func=None, jit_compile=False, **kwargs)[source]¶
Applies the default tf.function wrapper to the given function. Only applies in the TensorFlow backend.
Modules
- Autograd wrapper for different backends.
- JAX utilities for zea.
- Simple implementation of optimizers that support multi-backend autodiff.
- TensorFlow Ultrasound Beamforming Library.
- PyTorch Ultrasound Beamforming Library.