"""
Mask generation utilities.
These masks are used as a measurement operator for focused scan-line subsampling.
"""
from __future__ import annotations
from typing import List
import keras
from keras import ops
from zea.agent.gumbel import hard_straight_through
from zea.func.tensor import nonzero
_DEFAULT_DTYPE = "bool"
[docs]
def indices_to_k_hot(
indices,
n_possible_actions: int,
dtype=_DEFAULT_DTYPE,
):
"""Convert a list of indices to a k-hot encoded vector.
A k-hot encoded vector is suitable during tracing when the number of actions can change.
This is the default represenation for actions in zea.
Args:
indices (Tensor): Indices of selected actions to set to 1 of shape (..., n_actions).
n_possible_actions (int): Total number of possible actions.
dtype (str, optional): Data type of the mask. Defaults to "bool".
Returns:
Tensor: k-hot-encoded vector of shape (..., n_possible_actions).
"""
indices = ops.moveaxis(indices, -1, 0) # move n_actions to the front for one_hot
k_hot_encoded = ops.any(ops.one_hot(indices, n_possible_actions, dtype="bool"), axis=0)
# Cast to desired dtype, because ops.any will always return bool
return ops.cast(k_hot_encoded, dtype=dtype)
[docs]
def k_hot_to_indices(selected_lines, n_actions: int, fill_value=-1):
"""Convert k-hot encoded lines to indices of selected actions.
Args:
selected_lines (Tensor): k-hot encoded lines of shape (batch_size, n_possible_actions).
n_actions (int): Number of lines selected.
fill_value (int, optional): Value to fill in case there are not enough selected actions.
Defaults to -1.
Returns:
Tensor: Indices of selected actions of shape (batch_size, n_actions).
If there are fewer than `n_actions` selected, the remaining indices will be
filled with `fill_value`.
"""
# Find nonzero indices for each frame
def get_nonzero(row):
return nonzero(row > 0, size=n_actions, fill_value=fill_value)[0]
indices = ops.vectorized_map(get_nonzero, selected_lines)
return indices
def _assert_equal_spacing(n_actions, n_possible_actions):
assert n_possible_actions % n_actions == 0, (
"Number of actions must divide evenly into possible actions to use equispaced sampling. "
"If you do not care about equal spacing, set `assert_equal_spacing=False`."
)
[docs]
def initial_equispaced_lines(
n_actions, n_possible_actions, dtype=_DEFAULT_DTYPE, assert_equal_spacing=True
):
"""Generate an initial equispaced k-hot line mask.
For example, if ``n_actions=2`` and ``n_possible_actions=6``,
then ``initial_mask=[1, 0, 0, 1, 0, 0]``.
Args:
n_actions (int): Number of actions to be selected.
n_possible_actions (int): Number of possible actions.
dtype (str, optional): Data type of the mask. Defaults to "bool".
assert_equal_spacing (bool, optional): If True, asserts that
`n_possible_actions` is divisible by `n_actions`, this means that every
line will have the exact same spacing. Otherwise, there might be
some spacing differences. Defaults to True.
Returns:
Tensor: k-hot-encoded line vector of shape (n_possible_actions).
Needs to be converted to image size.
"""
assert n_actions > 0, "Number of actions must be > 0."
assert n_possible_actions > 0, "Number of possible actions must be > 0."
assert n_actions <= n_possible_actions, (
"Number of actions must be less than or equal to number of possible actions."
)
if assert_equal_spacing:
_assert_equal_spacing(n_actions, n_possible_actions)
# Distribute indices as evenly as possible
# This approach ensures spacing differs by at most 1 when not divisible
step = n_possible_actions / n_actions
selected_indices = ops.cast(
ops.round(ops.arange(0, n_actions, dtype="float32") * step), "int32"
)
return indices_to_k_hot(selected_indices, n_possible_actions, dtype=dtype)
[docs]
def next_equispaced_lines(previous_lines, shift=1):
"""
Rolls the previous equispaced mask of shape (..., n_possible_actions) to the right by
`shift` which is 1 by default.
"""
return ops.roll(previous_lines, shift=shift, axis=-1)
[docs]
def lines_to_im_size(lines, img_size: tuple):
"""
Convert k-hot-encoded line vectors to image size.
Args:
lines (Tensor): shape is (n_masks, n_possible_actions)
img_size (tuple): (height, width)
Returns:
Tensor: Masks of shape (n_masks, img_size, img_size)
"""
height, width = img_size
remainder = width % ops.shape(lines)[1]
assert remainder == 0, (
f"Width must be divisible by number of actions. Got remainder: {remainder}."
)
# Repeat till width of image
masks = ops.repeat(lines, width // ops.shape(lines)[1], axis=1)
# Repeat till height of image
masks = ops.repeat(masks[:, None], height, axis=1)
return masks
[docs]
def make_line_mask(
line_indices: List[int],
image_shape: List[int],
line_width: int = 1,
dtype=_DEFAULT_DTYPE,
):
"""
Creates a mask with vertical (i.e. second axis) lines at specified indices.
Args:
line_indices (List[int]): A list of indices where the lines should be drawn.
image_shape (List[int]): The shape of the image as [height, width, channels].
line_width (int, optional): The width of each line. Defaults to 1.
dtype (str, optional): The data type of the mask. Defaults to "bool".
Returns:
mask (Tensor): A tensor of the same shape as `image_shape` with lines drawn
at the specified indices.
"""
height, width, channels = image_shape
# Create k-hot vector for the line indices
k_hot = indices_to_k_hot(line_indices, width // line_width, dtype=dtype)
# Expand to (1, n_possible_actions) for lines_to_im_size
k_hot = ops.expand_dims(k_hot, axis=0)
# Use lines_to_im_size to create the mask of shape (1, height, width)
mask_2d = lines_to_im_size(k_hot, (height, width))[0]
# Expand to (height, width, channels)
return ops.broadcast_to(mask_2d[..., None], (height, width, channels))