Source code for zea.backend.jax

"""Jax utilities for zea."""

import jax


def str_to_jax_device(device):
    """Convert a device string to a JAX device.

    The string has the form ``'<type>'`` or ``'<type>:<index>'``. Matching is
    case-insensitive, and ``'cuda'`` is accepted as an alias for ``'gpu'``.
    When no index is given, the first device of that type is returned.

    Args:
        device (str): Device string, e.g. ``'gpu:0'``, ``'cuda:1'`` or ``'cpu'``.

    Returns:
        jax.Device: The corresponding JAX device.

    Raises:
        ValueError: If ``device`` is not a string, is malformed (more than one
            ``':'`` separator or a non-integer index), no devices of the
            requested type are found, or the requested index is out of range.
            Errors raised by ``jax.devices`` itself (e.g. for a backend that
            cannot be initialized) propagate unchanged.
    """
    if not isinstance(device, str):
        raise ValueError(f"Device must be a string, got {type(device)}")

    # Keep ``device`` as the caller's original string so error messages can
    # report it verbatim; work on the normalized, split form separately.
    parts = device.lower().replace("cuda", "gpu").split(":")
    if len(parts) > 2:
        raise ValueError(
            f"Invalid device string '{device}'; expected '<type>' or '<type>:<index>'."
        )

    device_type = parts[0]
    if len(parts) == 2:
        try:
            device_number = int(parts[1])
        except ValueError as err:
            raise ValueError(
                f"Invalid device index in '{device}'; expected an integer."
            ) from err
    else:
        # No index given: default to the first device of this type.
        device_number = 0

    available = jax.devices(device_type)
    if len(available) == 0:
        raise ValueError(f"No JAX devices available for type '{device_type}'.")
    if not 0 <= device_number < len(available):
        raise ValueError(
            f"Device '{device}' is not available; JAX devices found: {available}"
        )
    return available[device_number]