# Source code for zea.backend.torch

"""PyTorch Ultrasound Beamforming Library.

Initialize modules for registries.
"""

import torch


def func_on_device(func, device, *args, **kwargs):
    """Move all tensor arguments to a specified device, then call the function.

    Tensors are searched for recursively inside lists, tuples (including
    named tuples) and dicts. Any other object is passed through unchanged.

    Args:
        func (callable): Function to be called.
        device (str or torch.device or None): Device to move tensors to. If
            ``None``, ``func`` is called with the arguments untouched.
        *args: Positional arguments to be passed to the function.
        **kwargs: Keyword arguments to be passed to the function.

    Returns:
        The output of the function.
    """
    if device is None:
        return func(*args, **kwargs)

    if isinstance(device, str):
        device = torch.device(device)

    def move_to_device(x):
        if isinstance(x, torch.Tensor):
            return x.to(device)
        if isinstance(x, tuple) and hasattr(x, "_fields"):
            # Named tuples take their fields as separate positional
            # arguments, so they cannot be rebuilt from a single iterable.
            return type(x)(*(move_to_device(i) for i in x))
        if isinstance(x, (list, tuple)):
            return type(x)(move_to_device(i) for i in x)
        if isinstance(x, dict):
            # The result is always a plain dict, also for dict subclasses.
            return {k: move_to_device(v) for k, v in x.items()}
        return x

    args = move_to_device(args)
    kwargs = move_to_device(kwargs)
    return func(*args, **kwargs)