from typing import List, Tuple, Union
import numpy as np
from keras import ops
from keras.src.layers.preprocessing.data_layer import DataLayer
from zea.func import normalize
from zea.func.tensor import gaussian_filter
from zea.internal.registry import ops_registry
from zea.ops.base import Filter, Operation
from zea.utils import map_negative_indices
[docs]
@ops_registry("gaussian_blur")
class GaussianBlur(Filter):
"""
GaussianBlur is an operation that applies a Gaussian blur to an input image.
Uses scipy.ndimage.gaussian_filter to create a kernel.
"""
def __init__(
self,
sigma: float,
order: int | Tuple[int] = 0,
mode: str = "symmetric",
cval: float | None = None,
truncate: float = 4.0,
axes: Tuple[int] = (-3, -2),
**kwargs,
):
"""
Args:
sigma (float or tuple): Standard deviation for Gaussian kernel. The standard deviations
of the Gaussian filter are given for each axis as a sequence, or as a single number,
in which case it is equal for all axes.
order (int or Tuple[int]): The order of the filter along each axis is given as a
sequence of integers, or as a single number. An order of 0 corresponds to
convolution with a Gaussian kernel. A positive order corresponds to convolution
with that derivative of a Gaussian. Default is 0.
mode (str, optional): Padding mode for the input image. Default is 'symmetric'.
See [keras docs](https://www.tensorflow.org/api_docs/python/tf/keras/ops/pad) for
all options and [tensorflow docs](https://www.tensorflow.org/api_docs/python/tf/pad)
for some examples. Note that the naming differs from scipy.ndimage.gaussian_filter!
cval (float, optional): Value to fill past edges of input if mode is 'constant'.
Default is None.
truncate (float, optional): Truncate the filter at this many standard deviations.
Default is 4.0.
axes (Tuple[int], optional): If None, input is filtered along all axes. Otherwise, input
is filtered along the specified axes. When axes is specified, any tuples used for
sigma, order, mode and/or radius must match the length of axes. The ith entry in
any of these tuples corresponds to the ith entry in axes. Default is (-3, -2),
which corresponds to the height and width dimensions of a
(..., height, width, channels) tensor.
"""
super().__init__(**kwargs)
self.sigma = sigma
self.order = order
self.mode = mode
self.cval = cval
self.truncate = truncate
self.axes = axes
[docs]
def call(self, **kwargs):
"""Apply a Gaussian filter to the input data.
Args:
data (ops.Tensor): Input image data of shape (height, width, channels) with
optional batch dimension if ``self.with_batch_dim``.
"""
data = kwargs[self.key]
axes = self._resolve_filter_axes(data, self.axes)
out = gaussian_filter(
data, self.sigma, self.order, self.mode, self.cval, self.truncate, axes
)
return {self.output_key: out}
[docs]
@ops_registry("normalize")
class Normalize(Operation):
"""Normalize data to a given range."""
ADD_OUTPUT_KEYS = ["minval", "maxval"]
def __init__(self, output_range=None, input_range=None, **kwargs):
super().__init__(**kwargs)
if output_range is None:
output_range = (0, 1)
self.output_range = self.to_float32(output_range)
self.input_range = self.to_float32(input_range)
if len(self.output_range) != 2:
raise ValueError(
f"output_range must have exactly 2 elements, got {len(self.output_range)}"
)
if self.input_range is not None and len(self.input_range) != 2:
raise ValueError(
f"input_range must have exactly 2 elements, got {len(self.input_range)}"
)
[docs]
@staticmethod
def to_float32(data):
"""Converts an iterable to float32 and leaves None values as is."""
return (
[np.float32(x) if x is not None else None for x in data] if data is not None else None
)
@property
def valid_keys(self):
if self.input_range is None:
return super().valid_keys.union({"maxval", "minval"})
else:
return super().valid_keys
[docs]
def call(self, **kwargs):
"""Normalize data to a given range.
Args:
output_range (tuple, optional): Range to which data should be mapped.
Defaults to (0, 1).
input_range (tuple, optional): Range of input data. If None, the range
of the input data will be computed. Defaults to None.
Returns:
dict: Dictionary containing normalized data, along with the computed
or provided input range (minval and maxval).
"""
data = kwargs[self.key]
# If input_range is not provided, try to get it from kwargs
# This allows you to normalize based on the first frame in a sequence and avoid flicker
if self.input_range is None:
maxval = kwargs.get("maxval", None)
minval = kwargs.get("minval", None)
# If input_range is provided, use it
else:
minval, maxval = self.input_range
# If input_range is still not provided, compute it from the data
if minval is None:
minval = ops.min(data)
if maxval is None:
maxval = ops.max(data)
normalized_data = normalize(
data, output_range=self.output_range, input_range=(minval, maxval)
)
return {self.output_key: normalized_data, "minval": minval, "maxval": maxval}
[docs]
@ops_registry("pad")
class Pad(Operation, DataLayer):
"""Pad layer for padding tensors to a specified shape."""
def __init__(
self,
target_shape: list | tuple,
uniform: bool = True,
axis: Union[int, List[int]] = None,
fail_on_bigger_shape: bool = True,
pad_kwargs: dict = None,
**kwargs,
):
super().__init__(**kwargs)
self.target_shape = target_shape
self.uniform = uniform
self.axis = axis
self.pad_kwargs = pad_kwargs or {}
self.fail_on_bigger_shape = fail_on_bigger_shape
@staticmethod
def _format_target_shape(shape_array, target_shape, axis):
if isinstance(axis, int):
axis = [axis]
assert len(axis) == len(target_shape), (
"The length of axis must be equal to the length of target_shape."
)
axis = map_negative_indices(axis, len(shape_array))
target_shape = [
target_shape[axis.index(i)] if i in axis else shape_array[i]
for i in range(len(shape_array))
]
return target_shape
[docs]
def pad(
self,
z,
target_shape: list | tuple,
uniform: bool = True,
axis: Union[int, List[int]] = None,
fail_on_bigger_shape: bool = True,
**kwargs,
):
"""
Pads the input tensor `z` to the specified shape.
Parameters:
z (tensor): The input tensor to be padded.
target_shape (list or tuple): The target shape to pad the tensor to.
uniform (bool, optional): If True, ensures that padding is uniform (even on both sides).
Default is False.
axis (int or list of int, optional): The axis or axes along which `target_shape` was
specified. If None, `len(target_shape) == `len(ops.shape(z))` must hold.
Default is None.
fail_on_bigger_shape (bool, optional): If True (default), raises an error if any target
dimension is smaller than the input shape; if False, pads only where the
target shape exceeds the input shape and leaves other dimensions unchanged.
kwargs: Additional keyword arguments to pass to the padding function.
Returns:
tensor: The padded tensor with the specified shape.
"""
shape_array = self.backend.shape(z)
# When axis is provided, convert target_shape
if axis is not None:
target_shape = self._format_target_shape(shape_array, target_shape, axis)
if not fail_on_bigger_shape:
target_shape = [max(target_shape[i], shape_array[i]) for i in range(len(shape_array))]
# Compute the padding required for each dimension
pad_shape = np.array(target_shape) - shape_array
# Create the paddings array
if uniform:
# if odd, pad more on the left, same as:
# https://keras.io/api/layers/preprocessing_layers/image_preprocessing/center_crop/
right_pad = pad_shape // 2
left_pad = pad_shape - right_pad
paddings = np.stack([right_pad, left_pad], axis=1)
else:
paddings = np.stack([np.zeros_like(pad_shape), pad_shape], axis=1)
if np.any(paddings < 0):
raise ValueError(
f"Target shape {target_shape} must be greater than or equal "
f"to the input shape {shape_array}."
)
return self.backend.numpy.pad(z, paddings, **kwargs)
[docs]
def call(self, **kwargs):
data = kwargs[self.key]
padded_data = self.pad(
data,
self.target_shape,
self.uniform,
self.axis,
self.fail_on_bigger_shape,
**self.pad_kwargs,
)
return {self.output_key: padded_data}
[docs]
@ops_registry("threshold")
class Threshold(Operation):
"""Threshold an array, setting values below/above a threshold to a fill value."""
def __init__(
self,
threshold_type="hard",
below_threshold=True,
fill_value="min",
**kwargs,
):
super().__init__(**kwargs)
if threshold_type not in ("hard", "soft"):
raise ValueError("threshold_type must be 'hard' or 'soft'")
self.threshold_type = threshold_type
self.below_threshold = below_threshold
self.fill_value = fill_value
# Define threshold function at init
if threshold_type == "hard":
if below_threshold:
self._threshold_func = lambda data, threshold, fill: ops.where(
data < threshold, fill, data
)
else:
self._threshold_func = lambda data, threshold, fill: ops.where(
data > threshold, fill, data
)
else: # soft
if below_threshold:
self._threshold_func = lambda data, threshold, fill: (
ops.maximum(data - threshold, 0) + fill
)
else:
self._threshold_func = lambda data, threshold, fill: (
ops.minimum(data - threshold, 0) + fill
)
def _resolve_fill_value(self, data, threshold):
"""Get the fill value based on the fill_value_type."""
fv = self.fill_value
if isinstance(fv, (int, float)):
return ops.convert_to_tensor(fv, dtype=data.dtype)
elif fv == "min":
return ops.min(data)
elif fv == "max":
return ops.max(data)
elif fv == "threshold":
return threshold
else:
raise ValueError("Unknown fill_value")
[docs]
def call(
self,
threshold=None,
percentile=None,
**kwargs,
):
"""Threshold the input data.
Args:
threshold: Numeric threshold.
percentile: Percentile to derive threshold from.
Returns:
Tensor with thresholding applied.
"""
data = kwargs[self.key]
if (threshold is None) == (percentile is None):
raise ValueError("Pass either threshold or percentile, not both or neither.")
if percentile is not None:
# Convert percentile to quantile value (0-1 range)
threshold = ops.quantile(data, percentile / 100.0)
fill_value = self._resolve_fill_value(data, threshold)
result = self._threshold_func(data, threshold, fill_value)
return {self.output_key: result}