Source code for zea.utils

"""General utility functions."""

import collections.abc
import datetime
import time
from functools import wraps
from statistics import mean, median, stdev

import keras
import yaml

from zea import log


[docs] def canonicalize_axis(axis, num_dims) -> int: """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" if not -num_dims <= axis < num_dims: raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}") if axis < 0: axis = axis + num_dims return axis
[docs] def map_negative_indices(indices: list, num_dims: int): """Maps negative indices for array indexing to positive indices. Example: >>> from zea.utils import map_negative_indices >>> map_negative_indices([-1, -2], 5) [4, 3] """ return [canonicalize_axis(idx, num_dims) for idx in indices]
[docs] def strtobool(val: str): """Convert a string representation of truth to True or False. True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if 'val' is anything else. """ assert isinstance(val, str), f"Input value must be a string, not {type(val)}" val = val.lower() if val in ("y", "yes", "t", "true", "on", "1"): return True elif val in ("n", "no", "f", "false", "off", "0"): return False else: raise ValueError(f"invalid truth value {val}")
[docs] def update_dictionary(dict1: dict, dict2: dict, keep_none: bool = False) -> dict: """Updates dict1 with values dict2 Args: dict1 (dict): base dictionary dict2 (dict): update dictionary keep_none (bool, optional): whether to keep keys with None values in dict2. Defaults to False. Returns: dict: updated dictionary """ if not keep_none: dict2 = {k: v for k, v in dict2.items() if v is not None} # dict merging python > 3.9: default_scan_params | config_scan_params dict_out = {**dict1, **dict2} return dict_out
[docs] def get_date_string(string: str = None): """Generate a date string for current time, according to format specified by `string`. Refer to the documentation of the datetime module for more information on the formatting options. If no string is specified, the default format is used: "%Y_%m_%d_%H%M%S". """ if string is not None and not isinstance(string, str): raise TypeError("Input must be a string.") # Get the current time now = datetime.datetime.now() # If no string is specified, use the default format if string is None: string = "%Y_%m_%d_%H%M%S" # Generate the date string date_str = now.strftime(string) return date_str
[docs] def date_string_to_readable(date_string: str, include_time: bool = False): """Converts a date string to a more readable format. Args: date_string (str): The input date string. include_time (bool, optional): Whether to include the time in the output. Defaults to False. Returns: str: The date string in a more readable format. """ date = datetime.datetime.strptime(date_string, "%Y_%m_%d_%H%M%S") if include_time: return date.strftime("%B %d, %Y %I:%M %p") else: return date.strftime("%B %d, %Y")
[docs] def deep_compare(obj1, obj2): """Recursively compare two objects for equality.""" # Only recurse into dicts if isinstance(obj1, dict) and isinstance(obj2, dict): if obj1.keys() != obj2.keys(): return False return all(deep_compare(obj1[k], obj2[k]) for k in obj1) # If not dict, but both are iterable (excluding strings/bytes), compare items if ( isinstance(obj1, collections.abc.Iterable) and isinstance(obj2, collections.abc.Iterable) and not isinstance(obj1, (str, bytes)) and not isinstance(obj2, (str, bytes)) ): return all(deep_compare(a, b) for a, b in zip(obj1, obj2)) # Fallback to direct comparison (also handles int, float, str, etc.) return obj1 == obj2
[docs] def block_until_ready(func): """Decorator that ensures asynchronous (gpu) operations complete before returning.""" if keras.backend.backend() == "jax": import jax def _block(value): if hasattr(value, "__array__"): return jax.block_until_ready(value) else: return value else: def _block(value): if hasattr(value, "__array__"): # convert to numpy but return as original type _ = keras.ops.convert_to_numpy(value) return value @wraps(func) def wrapper(*args, **kwargs): result = func(*args, **kwargs) # Handle different return types if isinstance(result, (list, tuple)): # For multiple outputs, block each one blocked_results = [_block(r) for r in result] return type(result)(blocked_results) elif isinstance(result, dict): # For dict outputs, block array values return {k: _block(v) for k, v in result.items()} else: # Single output return _block(result) return wrapper
[docs] class FunctionTimer: """ A decorator class for timing the execution of functions. Example: .. doctest:: >>> from zea.utils import FunctionTimer >>> timer = FunctionTimer() >>> my_function = lambda: sum(range(10)) >>> my_function = timer(my_function, name="my_function") >>> _ = my_function() >>> print(timer.get_stats("my_function")) # doctest: +ELLIPSIS {'mean': ..., 'median': ..., 'std_dev': ..., 'min': ..., 'max': ..., 'count': ...} """ def __init__(self): self.timings = {} self.last_append = 0 self.decorated_functions = {} # Track decorated functions
[docs] def __call__(self, func, name=None): _name = name if name is not None else func.__name__ # Create a unique identifier for this function func_id = id(func) # Check if this exact function has already been decorated if func_id in self.decorated_functions: existing_name = self.decorated_functions[func_id] raise ValueError( f"Function '{func.__name__}' (id: {func_id}) has already been " f"decorated with timer name '{existing_name}'. " f"Cannot decorate the same function instance multiple times." ) # Handle name conflicts by appending a suffix original_name = _name counter = 1 while _name in self.timings: _name = f"{original_name}_{counter}" counter += 1 # Initialize timing storage for this function self.timings[_name] = [] # Track this decorated function self.decorated_functions[func_id] = _name func = block_until_ready(func) @wraps(func) def wrapper(*args, **kwargs): start_time = time.perf_counter() result = func(*args, **kwargs) end_time = time.perf_counter() elapsed_time = end_time - start_time # Store the timing result self.timings[_name].append(elapsed_time) return result return wrapper
def _parse_drop_first(self, drop_first: bool | int): if isinstance(drop_first, bool): idx = 1 if drop_first else 0 elif isinstance(drop_first, int): idx = drop_first else: raise ValueError("drop_first must be a boolean or an integer.") return idx
[docs] def get_stats(self, func_name, drop_first: bool | int = False): """Calculate statistics for the given function.""" if func_name not in self.timings: raise ValueError(f"No timings recorded for function '{func_name}'.") idx = self._parse_drop_first(drop_first) times = self.timings[func_name][idx:] return { "mean": mean(times), "median": median(times), "std_dev": stdev(times) if len(times) > 1 else 0, "min": min(times), "max": max(times), "count": len(times), }
[docs] def export_to_yaml(self, filename): """Export the timing data to a YAML file.""" with open(filename, "w", encoding="utf-8") as f: yaml.dump(self.timings, f, default_flow_style=False) print(f"Timing data exported to '{filename}'.")
[docs] def append_to_yaml(self, filename, func_name): """Append the timing data to a YAML file.""" cropped_timings = self.timings[func_name][self.last_append :] with open(filename, "a", encoding="utf-8") as f: yaml.dump(cropped_timings, f, default_flow_style=False) self.last_append = len(self.timings[func_name])
[docs] def print(self, drop_first: bool | int = False, total_time: bool = False): """Print timing statistics for all recorded functions using formatted output.""" # Print title print(log.bold("Function Timing Statistics")) header = ( f"{log.cyan('Function'):<30} {log.green('Mean'):<22} " f"{log.green('Median'):<22} {log.green('Std Dev'):<22} " f"{log.yellow('Min'):<22} {log.yellow('Max'):<22} {log.magenta('Count'):<18}" ) length = len(log.remove_color_escape_codes(header)) print("=" * length) # Print header print(header) print("-" * length) # Print data rows for func_name in self.timings.keys(): stats = self.get_stats(func_name, drop_first=drop_first) row = ( f"{log.cyan(func_name):<30} " f"{log.green(log.number_to_str(stats['mean'], 6)):<22} " f"{log.green(log.number_to_str(stats['median'], 6)):<22} " f"{log.green(log.number_to_str(stats['std_dev'], 6)):<22} " f"{log.yellow(log.number_to_str(stats['min'], 6)):<22} " f"{log.yellow(log.number_to_str(stats['max'], 6)):<22} " f"{log.magenta(str(stats['count'])):<18}" ) print(row) if total_time: idx = self._parse_drop_first(drop_first) total = sum(mean(times[idx:]) for times in self.timings.values()) print("-" * length) print(f"{log.bold('Mean Total Time:')} {log.bold(log.number_to_str(total, 6))} seconds")