"""Augmentation layers for ultrasound data."""
import keras
import numpy as np
from keras import layers, ops
from zea.func.tensor import is_jax_prng_key, split_seed
[docs]
class RandomCircleInclusion(layers.Layer):
"""
Adds a circular inclusion to the image, optionally at random locations.
Since this can accept N-dimensional inputs, you'll need to specify your
``circle_axes`` -- these are the axes onto which a circle will be drawn.
This circle will then be broadcast along the remaining dimensions.
You can then optionally specify whether there is a batch dim,
and whether the circles should be located randomly across that batch.
For example, if you have a batch of videos, e.g. of shape [batch, frame, height, width],
then you might want to specify ``circle_axes=(2, 3)``, and
``randomize_location_across_batch=True``. This would result in a circle that is located
in the same place per video, but different locations for different videos.
Once your method has recovered the circles, you can evaluate them using
the ``evaluate_recovered_circle_accuracy()`` method, which will expect an input
shape matching your inputs to ``call()``.
"""
def __init__(
self,
radius: int | tuple[int, int],
fill_value: float = 1.0,
circle_axes: tuple[int, int] = (1, 2),
with_batch_dim=True,
return_centers=False,
recovery_threshold=0.1,
randomize_location_across_batch=True,
seed=None,
width_range: tuple[int, int] = None,
height_range: tuple[int, int] = None,
**kwargs,
):
"""
Initialize RandomCircleInclusion.
Args:
radius (int or tuple[int, int]): Radius of the circle/ellipse to include.
fill_value (float): Value to fill inside the circle.
circle_axes (tuple[int, int]): Axes along which to draw the circle
(height, width).
with_batch_dim (bool): Whether input has a batch dimension.
return_centers (bool): Whether to return circle centers along with images.
recovery_threshold (float): Threshold for considering a pixel as recovered.
randomize_location_across_batch (bool): If True (and with_batch_dim=True),
each batch element gets a different random center. If False, all batch
elements share the same center.
seed (Any): Optional random seed for reproducibility.
width_range (tuple[int, int], optional): Range (min, max) for circle
center x (width axis).
height_range (tuple[int, int], optional): Range (min, max) for circle
center y (height axis).
**kwargs: Additional keyword arguments for the parent Layer.
Example:
.. doctest::
>>> from zea.data.augmentations import RandomCircleInclusion
>>> from keras import ops
>>> layer = RandomCircleInclusion(
... radius=5,
... circle_axes=(1, 2),
... with_batch_dim=True,
... )
>>> image = ops.zeros((1, 28, 28), dtype="float32")
>>> out = layer(image) # doctest: +SKIP
"""
super().__init__(**kwargs)
# Validate randomize_location_across_batch only makes sense with batch dim
if not with_batch_dim and not randomize_location_across_batch:
raise ValueError(
"randomize_location_across_batch=False is only applicable when "
"with_batch_dim=True. When with_batch_dim=False, there is no batch "
"to randomize across."
)
# Convert radius to tuple if int, else validate tuple
if isinstance(radius, int):
if radius <= 0:
raise ValueError(f"radius must be a positive integer, got {radius}.")
self.radius = (radius, radius)
elif isinstance(radius, tuple) and len(radius) == 2:
rx, ry = radius
if not all(isinstance(r, int) for r in (rx, ry)):
raise TypeError(f"radius tuple must contain two integers, got {radius!r}.")
if rx <= 0 or ry <= 0:
raise ValueError(f"radius components must be positive, got {radius!r}.")
self.radius = (rx, ry)
else:
raise TypeError("radius must be an int or a tuple of two ints")
self.fill_value = fill_value
self.circle_axes = circle_axes
self.with_batch_dim = with_batch_dim
self.return_centers = return_centers
self.recovery_threshold = recovery_threshold
self.randomize_location_across_batch = randomize_location_across_batch
self.seed = seed
self.width_range = width_range
self.height_range = height_range
self._axis1 = None
self._axis2 = None
self._perm = None
self._inv_perm = None
self._static_shape = None
self._static_batch = None
self._static_h = None
self._static_w = None
self._static_flat_batch = 1
[docs]
def build(self, input_shape):
"""
Build the layer and compute static shape and permutation info.
Args:
input_shape (tuple): Shape of the input tensor.
"""
rank = len(input_shape) - 1 if self.with_batch_dim else len(input_shape)
a1, a2 = self.circle_axes
if self.with_batch_dim and (a1 == 0 or a2 == 0):
raise ValueError("The circle axes should not be a batch dim")
if a1 < 0:
a1 += rank
elif a1 > 0 and self.with_batch_dim:
a1 -= 1
if a2 < 0:
a2 += rank
elif a2 > 0 and self.with_batch_dim:
a2 -= 1
if not (0 <= a1 < rank and 0 <= a2 < rank):
raise ValueError(f"circle_axes {self.circle_axes} out of range for rank {rank}")
if a1 == a2:
raise ValueError("circle_axes must be two distinct axes")
self._axis1, self._axis2 = a1, a2
all_axes = list(range(rank))
other_axes = [ax for ax in all_axes if ax not in (a1, a2)]
self._perm = other_axes + [a1, a2]
inv = [0] * rank
for i, ax in enumerate(self._perm):
inv[ax] = i
self._inv_perm = inv
if self.with_batch_dim:
input_shape = input_shape[1:] # ignore batch dim
permuted_shape = [input_shape[ax] for ax in self._perm]
if len(permuted_shape) > 2:
self._static_flat_batch = int(np.prod(permuted_shape[:-2]))
self._static_h = int(permuted_shape[-2])
self._static_w = int(permuted_shape[-1])
self._static_shape = tuple(permuted_shape)
# Validate that ellipse can fit within image bounds
rx, ry = self.radius
min_required_width = 2 * rx + 1
min_required_height = 2 * ry + 1
if self._static_w < min_required_width:
raise ValueError(
f"Image width ({self._static_w}) is too small for radius {rx}. "
f"Minimum required width: {min_required_width}"
)
if self._static_h < min_required_height:
raise ValueError(
f"Image height ({self._static_h}) is too small for radius {ry}. "
f"Minimum required height: {min_required_height}"
)
# Validate width_range and height_range if provided
if self.width_range is not None:
min_x, max_x = self.width_range
if min_x >= max_x:
raise ValueError(f"width_range must have min < max, got {self.width_range}")
if min_x < rx or max_x > self._static_w - rx:
raise ValueError(
f"width_range {self.width_range} would place circle outside image bounds. "
f"Valid range: [{rx}, {self._static_w - rx})"
)
if self.height_range is not None:
min_y, max_y = self.height_range
if min_y >= max_y:
raise ValueError(f"height_range must have min < max, got {self.height_range}")
if min_y < ry or max_y > self._static_h - ry:
raise ValueError(
f"height_range {self.height_range} would place circle outside image bounds. "
f"Valid range: [{ry}, {self._static_h - ry})"
)
super().build(input_shape)
[docs]
def compute_output_shape(self, input_shape):
"""
Compute output shape for the layer.
Args:
input_shape (tuple): Shape of the input tensor.
Returns:
tuple: The output shape (same as input).
"""
return input_shape
def _permute_axes_to_circle_last(self, x):
"""
Permute axes so that circle axes are last.
Args:
x (Tensor): Input tensor.
Returns:
Tensor: Tensor with circle axes as the last two dimensions.
"""
return ops.transpose(x, axes=self._perm)
def _flatten_batch_and_other_dims(self, x):
"""
Flatten all axes except the last two (circle axes).
Args:
x (Tensor): Input tensor with circle axes last.
Returns:
tuple: (reshaped tensor, flat batch size, height, width).
"""
shape = x.shape
flat_batch = int(np.prod(shape[:-2])) if len(shape) > 2 else 1
h, w = shape[-2], shape[-1]
return ops.reshape(x, [flat_batch, h, w]), flat_batch, h, w
def _make_circle_mask(self, centers, h, w, radius, dtype):
"""
Create a mask for each center (batch, h, w) using Keras ops.
Args:
centers (Tensor): Tensor of shape (batch, 2) with circle centers.
h (int): Height of the image.
w (int): Width of the image.
radius (tuple[int, int]): Radii of the ellipse (rx, ry).
dtype (str or dtype): Data type for the mask.
Returns:
Tensor: Mask of shape (batch, h, w).
"""
Y = ops.arange(h)
X = ops.arange(w)
Y, X = ops.meshgrid(Y, X, indexing="ij")
Y = ops.expand_dims(Y, 0) # (1, h, w)
X = ops.expand_dims(X, 0) # (1, h, w)
cx = centers[:, 0][:, None, None]
cy = centers[:, 1][:, None, None]
rx, ry = radius
# Ellipse equation: ((X-cx)/rx)^2 + ((Y-cy)/ry)^2 <= 1
dist = ((X - cx) / rx) ** 2 + ((Y - cy) / ry) ** 2
mask = ops.cast(dist <= 1, dtype)
return mask
[docs]
def call(self, x, seed=None):
"""
Apply the random circle inclusion augmentation.
Args:
x (Tensor): Input tensor.
seed (Any, optional): Optional random seed for reproducibility.
Returns:
Tensor or tuple: Augmented images, and optionally the circle
centers if return_centers is True.
"""
if keras.backend.backend() == "jax" and not is_jax_prng_key(seed):
if isinstance(seed, keras.random.SeedGenerator):
raise ValueError(
"When using JAX backend, please provide a jax.random.PRNGKey as seed, "
"instead of keras.random.SeedGenerator."
)
else:
raise TypeError(
f"When using JAX backend, seed must be a JAX PRNG key (created with "
f"jax.random.PRNGKey()), but got {type(seed)}. Note: jax.random.key() "
f"keys are not currently supported."
)
seed = seed if seed is not None else self.seed
if self.with_batch_dim:
x_is_symbolic_tensor = not isinstance(ops.shape(x)[0], int)
if x_is_symbolic_tensor:
if self.randomize_location_across_batch:
imgs, centers = ops.map(lambda arg: self._call(arg, seed), x)
else:
raise NotImplementedError(
"You cannot fix circle locations across batch while using "
+ "RandomCircleInclusion as a dataset augmentation, "
+ "since samples in a batch are handled independently."
)
else:
batch_size = ops.shape(x)[0]
if self.randomize_location_across_batch:
seeds = split_seed(seed, batch_size)
if all(s is seeds[0] for s in seeds):
imgs, centers = ops.map(lambda arg: self._call(arg, seeds[0]), x)
else:
imgs, centers = ops.map(
lambda args: self._call(args[0], args[1]), (x, seeds)
)
else:
# Generate one random center that will be used for all batch elements
img0, center0 = self._call(x[0], seed)
# Apply the same center to all batch elements
imgs_list, centers_list = [img0], [center0]
for i in range(1, batch_size):
img_aug, center_out = self._call_with_fixed_center(x[i], center0)
imgs_list.append(img_aug)
centers_list.append(center_out)
imgs = ops.stack(imgs_list, axis=0)
centers = ops.stack(centers_list, axis=0)
else:
imgs, centers = self._call(x, seed)
if self.return_centers:
return imgs, centers
else:
return imgs
def _call(self, x, seed):
"""
Internal method to apply the augmentation to a single image.
Args:
x (Tensor): Input image tensor with circle axes last.
seed (Any): Random seed for circle location.
Returns:
tuple: (augmented image, center coordinates).
"""
x = self._permute_axes_to_circle_last(x)
flat, flat_batch_size, h, w = self._flatten_batch_and_other_dims(x)
def _draw_circle_2d(img2d):
rx, ry = self.radius
# Determine allowed ranges for center
if self.width_range is not None:
min_x, max_x = self.width_range
else:
min_x, max_x = rx, w - rx
if self.height_range is not None:
min_y, max_y = self.height_range
else:
min_y, max_y = ry, h - ry
# Ensure the ellipse fits within the allowed region
cx = ops.cast(
keras.random.uniform((), min_x, max_x, seed=seed),
"int32",
)
new_seed, _ = split_seed(seed, 2) # ensure that cx and cy are independent
cy = ops.cast(
keras.random.uniform((), min_y, max_y, seed=new_seed),
"int32",
)
mask = self._make_circle_mask(
ops.stack([cx, cy])[None, :], h, w, (rx, ry), img2d.dtype
)[0]
img_aug = img2d * (1 - mask) + self.fill_value * mask
center = ops.stack([cx, cy])
return img_aug, center
aug_imgs, centers = ops.vectorized_map(_draw_circle_2d, flat)
aug_imgs = ops.reshape(aug_imgs, x.shape)
aug_imgs = ops.transpose(aug_imgs, axes=self._inv_perm)
centers_shape = [2] if flat_batch_size == 1 else [flat_batch_size, 2]
centers = ops.reshape(centers, centers_shape)
return (aug_imgs, centers)
def _apply_circle_mask(self, flat, center, h, w):
"""Apply circle mask to flattened image data.
Args:
flat (Tensor): Flattened image data of shape (flat_batch, h, w).
center (Tensor): Center coordinates, either (2,) or (flat_batch, 2).
h (int): Height of images.
w (int): Width of images.
Returns:
Tensor: Augmented images with circle applied.
"""
rx, ry = self.radius
# Ensure center has batch dimension for broadcasting
if len(center.shape) == 1:
# Single center (2,) -> broadcast to all slices
center_batched = ops.tile(ops.reshape(center, [1, 2]), [flat.shape[0], 1])
else:
# Already batched (flat_batch, 2)
center_batched = center
# Create masks for all slices using vectorized_map or broadcasting
masks = self._make_circle_mask(center_batched, h, w, (rx, ry), flat.dtype)
# Apply masks
aug_imgs = flat * (1 - masks) + self.fill_value * masks
return aug_imgs
def _call_with_fixed_center(self, x, fixed_center):
"""Apply augmentation using a pre-determined center.
Args:
x (Tensor): Input image tensor.
fixed_center (Tensor): Pre-determined center coordinates, either (2,)
for a single center or (flat_batch, 2) for per-slice centers.
Returns:
tuple: (augmented image, center coordinates).
"""
x = self._permute_axes_to_circle_last(x)
flat, flat_batch_size, h, w = self._flatten_batch_and_other_dims(x)
# Apply circle mask with fixed center (handles both single and batched centers)
aug_imgs = self._apply_circle_mask(flat, fixed_center, h, w)
aug_imgs = ops.reshape(aug_imgs, x.shape)
aug_imgs = ops.transpose(aug_imgs, axes=self._inv_perm)
# Return centers matching the expected shape
if len(fixed_center.shape) == 1:
# Single center (2,) -> broadcast to match flat_batch_size
if flat_batch_size == 1:
centers = fixed_center
else:
centers = ops.tile(ops.reshape(fixed_center, [1, 2]), [flat_batch_size, 1])
else:
# Already batched centers (flat_batch, 2)
centers = fixed_center
return (aug_imgs, centers)
[docs]
def get_config(self):
"""
Get layer configuration for serialization.
Returns:
dict: Dictionary of layer configuration.
"""
cfg = super().get_config()
cfg.update(
{
"radius": self.radius,
"fill_value": self.fill_value,
"circle_axes": self.circle_axes,
"return_centers": self.return_centers,
"width_range": self.width_range,
"height_range": self.height_range,
}
)
return cfg
[docs]
def evaluate_recovered_circle_accuracy(
self, images, centers, recovery_threshold, fill_value=None
):
"""
Evaluate the percentage of the true circle that has been recovered in the images,
and return a mask of the detected part of the circle.
Args:
images (Tensor): Tensor of images (any shape, with circle axes as specified).
centers (Tensor): Tensor of circle centers (matching batch size).
recovery_threshold (float): Threshold for considering a pixel as recovered.
fill_value (float, optional): Optionally override fill_value for cases
where image range has changed.
Returns:
Tuple[Tensor, Tensor]:
- percent_recovered: [batch] - average recovery percentage per batch element,
averaged across all non-batch dimensions (e.g., frames, slices)
- recovered_masks: [batch, flat_batch, h, w] or [batch, h, w] or [flat_batch, h, w]
depending on input shape - binary mask of detected circle regions
"""
fill_value = fill_value or self.fill_value
def _evaluate_recovered_circle_accuracy(image, center):
image_perm = self._permute_axes_to_circle_last(image)
h, w = image_perm.shape[-2], image_perm.shape[-1]
flat_image, _, _, _ = self._flatten_batch_and_other_dims(image_perm)
flat_center = ops.reshape(center, [-1, 2])
mask = self._make_circle_mask(flat_center, h, w, self.radius, flat_image.dtype)
diff = ops.abs(flat_image - fill_value)
recovered = ops.cast(diff <= recovery_threshold, flat_image.dtype) * mask
recovered_sum = ops.sum(recovered, axis=[1, 2])
mask_sum = ops.sum(mask, axis=[1, 2])
percent_recovered = recovered_sum / (mask_sum + 1e-8)
# recovered_mask: binary mask of detected part of the circle
recovered_mask = ops.cast(recovered > 0, flat_image.dtype)
return percent_recovered, recovered_mask
if self.with_batch_dim:
results = ops.vectorized_map(
lambda args: _evaluate_recovered_circle_accuracy(args[0], args[1]),
(images, centers),
)
percent_recovered, recovered_masks = results
# If there are multiple circles per batch element (e.g., multiple frames/slices),
# take the mean across all non-batch dimensions to get one value per batch element
if len(percent_recovered.shape) > 1:
# Average over all axes except the batch dimension (axis 0)
axes_to_reduce = tuple(range(1, len(percent_recovered.shape)))
percent_recovered = ops.mean(percent_recovered, axis=axes_to_reduce)
return percent_recovered, recovered_masks
else:
percent_recovered, recovered_mask = _evaluate_recovered_circle_accuracy(images, centers)
return percent_recovered, recovered_mask