zea.agent.selection
Action selection strategies
These selection strategies implement a variety of policies for choosing which focused
transmit to fire next, optionally given beliefs about the state of the tissue,
represented as particles.
For a comprehensive usage example, see Active perception for focused transmit steering.
All strategies are stateless: they do not maintain any internal state between calls.
Classes
- CovarianceSamplingLines – Covariance sampling action selection.
- EquispacedLines – Equispaced lines action selection.
- GreedyEntropy – Greedy entropy action selection.
- LinesActionModel – Base class for action selection methods that select lines.
- MaskActionModel – Base class for any action selection method that does masking.
- TaskBasedLines – Task-based line selection for maximizing information gain.
- UniformRandomLines – Uniform random lines action selection.
- class zea.agent.selection.CovarianceSamplingLines(n_actions, n_possible_actions, img_width, img_height, seed=42, n_masks=200)
Bases: LinesActionModel
Covariance sampling action selection.
This class models the line-to-line correlation to select the mask with the highest entropy.
Initialize the CovarianceSamplingLines action selection model.
- Parameters:
n_actions (int) – The number of actions the agent can take.
n_possible_actions (int) – The number of possible actions.
img_width (int) – The width of the input image.
img_height (int) – The height of the input image.
seed (int) – The seed for random number generation. Defaults to 42.
n_masks (int) – The number of masks. Defaults to 200.
- Raises:
AssertionError – If the image width is not divisible by n_possible_actions.
- random_uniform_lines(batch_size, seed=None)
Wrapper around the random_uniform_lines function that takes its remaining arguments from the class attributes.
- sample(particles, seed=None)
Sample the action using the covariance sampling method.
- Parameters:
particles (Tensor) – Particles of shape (batch_size, n_particles, h, w)
seed (int | SeedGenerator | jax.random.key, optional) – Seed for random number generation. Defaults to None.
- Returns:
Newly selected lines as k-hot vectors, shaped (batch_size, n_possible_actions)
Masks of shape (batch_size, img_height, img_width)
- Return type:
Tuple[Tensor, Tensor]
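The docstring states only that the line-to-line correlation is modelled in order to pick the mask with the highest entropy. One plausible reading is sketched below in plain Python for a single batch element: each column of each particle is summarised by its mean intensity, the covariance between lines is estimated across particles, and among n_masks random candidate line subsets the one whose covariance submatrix has the largest log-determinant (the Gaussian entropy up to constants) is kept. The helper names, the per-column mean summary and the diagonal jitter are assumptions, not the zea implementation; the sketch treats every image column as one possible action and needs at least two particles.

```python
import math
import random


def line_covariance(particles):
    """Covariance between lines across particles (hypothetical helper).

    particles: n_particles images, each a list of rows (h x w). Each column is
    treated as one line and summarised by its mean intensity (an assumption).
    """
    n = len(particles)
    h, w = len(particles[0]), len(particles[0][0])
    feats = [[sum(img[r][c] for r in range(h)) / h for c in range(w)] for img in particles]
    means = [sum(f[c] for f in feats) / n for c in range(w)]
    return [
        [sum((f[a] - means[a]) * (f[b] - means[b]) for f in feats) / (n - 1) for b in range(w)]
        for a in range(w)
    ]


def log_det(matrix, jitter=1e-6):
    """Log-determinant of a symmetric PSD matrix via Cholesky, with diagonal jitter."""
    k = len(matrix)
    a = [[matrix[i][j] + (jitter if i == j else 0.0) for j in range(k)] for i in range(k)]
    low = [[0.0] * k for _ in range(k)]
    for i in range(k):
        for j in range(i + 1):
            s = a[i][j] - sum(low[i][m] * low[j][m] for m in range(j))
            if i == j:
                low[i][i] = math.sqrt(max(s, jitter))
            else:
                low[i][j] = s / low[j][j]
    return 2.0 * sum(math.log(low[i][i]) for i in range(k))


def covariance_sampling(particles, n_actions, n_masks=200, seed=42):
    """Keep the random line subset whose covariance has the largest log-det."""
    rng = random.Random(seed)
    cov = line_covariance(particles)
    n_lines = len(cov)
    best_lines, best_score = None, -math.inf
    for _ in range(n_masks):
        lines = sorted(rng.sample(range(n_lines), n_actions))
        sub = [[cov[a][b] for b in lines] for a in lines]
        score = log_det(sub)  # Gaussian entropy up to additive/multiplicative constants
        if score > best_score:
            best_lines, best_score = lines, score
    k_hot = [1 if i in best_lines else 0 for i in range(n_lines)]
    return k_hot, best_lines


# Columns 0 and 2 vary independently across particles; columns 1 and 3 are constant.
particles = [
    [[1.0, 5.0, 0.0, 5.0]],
    [[-1.0, 5.0, 0.0, 5.0]],
    [[0.0, 5.0, 1.0, 5.0]],
    [[0.0, 5.0, -1.0, 5.0]],
]
k_hot, lines = covariance_sampling(particles, n_actions=2)
print(k_hot)  # [1, 0, 1, 0]
```

Scoring a fixed number of random candidates keeps the cost bounded, whereas enumerating every subset of n_actions lines grows combinatorially.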
- class zea.agent.selection.EquispacedLines(n_actions, n_possible_actions, img_width, img_height, assert_equal_spacing=True)
Bases: LinesActionModel
Equispaced lines action selection.
Creates masks with equispaced lines that sweep across the image.
Initialize the LinesActionModel.
- Parameters:
n_actions (int) – The number of actions the agent can take.
n_possible_actions (int) – The number of possible actions.
img_width (int) – The width of the input image.
img_height (int) – The height of the input image.
- stack_n_cols
The number of columns in the image that correspond to a single action.
- Type:
int
- initial_sample_stateless(batch_size=1)
Initial sample stateless.
Generates a batch of initial equispaced line masks.
- Returns:
Selected lines as k-hot vectors, shaped (batch_size, n_possible_actions)
Masks of shape (batch_size, img_height, img_width)
- Return type:
Tuple[Tensor, Tensor]
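To illustrate the equispaced layout, the sketch below builds k-hot vectors with n_actions lines spaced n_possible_actions // n_actions apart. Starting at index 0, and using a divisibility check as a stand-in for assert_equal_spacing, are assumptions; the function is hypothetical and uses plain lists in place of tensors.

```python
def initial_equispaced_lines(n_actions, n_possible_actions, batch_size=1, assert_equal_spacing=True):
    """Hypothetical sketch: k-hot rows with n_actions equally spaced lines."""
    if assert_equal_spacing:
        assert n_possible_actions % n_actions == 0, "lines cannot be equally spaced"
    spacing = n_possible_actions // n_actions
    row = [0] * n_possible_actions
    for i in range(n_actions):
        row[i * spacing] = 1  # first line at index 0 (assumed offset)
    return [list(row) for _ in range(batch_size)]


print(initial_equispaced_lines(n_actions=3, n_possible_actions=12))
# [[1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]]
```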
- sample(current_lines=None, batch_size=1)
Sample the action using the equispaced method.
Generates or updates an equispaced mask to sweep rightwards by one step across the image.
- Returns:
The mask of shape (batch_size, img_height, img_width)
- Return type:
Tensor
- sample_stateless(current_lines)
Sample stateless.
Updates an existing equispaced mask to sweep rightwards by one step across the image.
- Parameters:
current_lines – Currently selected lines as k-hot vectors, shaped (batch_size, n_possible_actions)
- Returns:
Newly selected lines as k-hot vectors, shaped (batch_size, n_possible_actions)
Masks of shape (batch_size, img_height, img_width)
- Return type:
Tuple[Tensor, Tensor]
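The one-step rightward sweep can be pictured as a circular shift of each k-hot row. The sketch assumes that a line leaving the right edge wraps around to the left; the helper sweep_right is hypothetical.

```python
def sweep_right(current_lines):
    """Shift every selected line one position to the right, wrapping at the edge.

    current_lines: list of k-hot rows, shape (batch_size, n_possible_actions).
    """
    return [row[-1:] + row[:-1] for row in current_lines]


lines = [[1, 0, 0, 1, 0, 0]]
for _ in range(3):
    lines = sweep_right(lines)
    print(lines)
# [[0, 1, 0, 0, 1, 0]]
# [[0, 0, 1, 0, 0, 1]]
# [[1, 0, 0, 1, 0, 0]]
```

With equispaced lines, the pattern returns to its starting position after spacing steps, so repeated calls cover every column.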
- class zea.agent.selection.GreedyEntropy(n_actions, n_possible_actions, img_width, img_height, mean=0, std_dev=1, num_lines_to_update=5, entropy_sigma=1.0, average_entropy_across_batch=False)
Bases: LinesActionModel
Greedy entropy action selection.
Selects the max entropy line and reweights the entropy values around it, approximating the decrease in entropy that would occur from observing that line.
The neighbouring values are decreased by a Gaussian function centered at the selected line.
Initialize the GreedyEntropy action selection model.
- Parameters:
n_actions (int) – The number of actions the agent can take.
n_possible_actions (int) – The number of possible actions.
img_width (int) – The width of the input image.
img_height (int) – The height of the input image.
mean (float) – The mean of the RBF. Defaults to 0.
std_dev (float) – The standard deviation of the RBF. Defaults to 1.
num_lines_to_update (int) – The number of lines around the selected line to update. Must be odd. Defaults to 5.
entropy_sigma (float) – The standard deviation of the Gaussian Mixture components used to approximate the posterior. Defaults to 1.0.
average_entropy_across_batch (bool) – Whether to average entropy across the batch when selecting lines. This can be useful when selecting planes in 3D imaging, where the batch dimension represents a third spatial dimension. Defaults to False.
- static compute_pairwise_pixel_gaussian_error(particles, entropy_sigma=1.0)
Compute the pairwise pixelwise Gaussian error.
This function computes the Gaussian error between each pair of particles, separately at every pixel. This can be used to approximate the entropy of a Gaussian mixture model whose component means are the particles. For more details, see Section 4 of https://arxiv.org/abs/2406.14388.
- Parameters:
particles (Tensor) – Particles of shape (batch_size, n_particles, *pixels)
entropy_sigma (float, optional) – The standard deviation of the Gaussian Mixture components used to approximate the posterior. Defaults to 1.0.
- Returns:
batch of pixelwise pairwise Gaussian errors, of shape (batch_size, n_particles, n_particles, *pixels)
- Return type:
Tensor
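The docstring does not give the exact form of the Gaussian error. The sketch below assumes it is the scaled squared difference (x_i - x_j)^2 / (2 sigma^2) between particles i and j at each pixel, with sigma = entropy_sigma. It handles one batch element with flattened pixels, and the function name is hypothetical.

```python
def pairwise_pixel_gaussian_error(particles, entropy_sigma=1.0):
    """Hypothetical sketch of a pairwise, pixelwise Gaussian error.

    particles: n_particles lists of pixel values (one batch element, pixels flattened).
    Returns e[i][j][p] = (x_i[p] - x_j[p]) ** 2 / (2 * entropy_sigma ** 2),
    i.e. shape (n_particles, n_particles, n_pixels).
    """
    denom = 2.0 * entropy_sigma ** 2
    return [
        [[(a - b) ** 2 / denom for a, b in zip(xi, xj)] for xj in particles]
        for xi in particles
    ]


err = pairwise_pixel_gaussian_error([[0.0, 1.0], [2.0, 1.0]])
print(err[0][1])  # [2.0, 0.0]
```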
- static compute_pixelwise_entropy(particles, entropy_sigma=1.0)
Compute the entropy of each pixel using a Gaussian Mixture Model approximation of the posterior distribution. For more details, see Section VI.B of https://arxiv.org/pdf/2410.13310.
- Parameters:
particles (Tensor) – Particles of shape (batch_size, n_particles, *pixels)
entropy_sigma (float, optional) – The standard deviation of the Gaussian Mixture components used to approximate the posterior. Defaults to 1.0.
- Returns:
batch of entropies per pixel, of shape (batch_size, *pixels)
- Return type:
Tensor
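A common pairwise estimator of the entropy of an equally weighted Gaussian mixture with component means x_1, ..., x_N and shared standard deviation sigma is H = -(1/N) sum_i log((1/N) sum_j exp(-(x_i - x_j)^2 / (2 sigma^2))), up to an additive constant. The plain-Python sketch below evaluates it per pixel, using log-sum-exp for numerical stability. Whether zea uses exactly this estimator and these constants is an assumption; the referenced paper is authoritative.

```python
import math


def pixelwise_entropy(particles, entropy_sigma=1.0):
    """Hypothetical sketch: per-pixel GMM entropy estimate, up to an additive constant.

    particles: n_particles lists of pixel values (one batch element, pixels flattened).
    """
    n = len(particles)
    denom = 2.0 * entropy_sigma ** 2
    entropies = []
    for p in range(len(particles[0])):
        values = [x[p] for x in particles]
        total = 0.0
        for xi in values:
            exponents = [-((xi - xj) ** 2) / denom for xj in values]
            m = max(exponents)  # log-sum-exp for numerical stability
            lse = m + math.log(sum(math.exp(e - m) for e in exponents))
            total += lse - math.log(n)  # log of the mixture density at x_i
        entropies.append(-total / n)
    return entropies


# Pixel 0: particles agree, so entropy is (near) zero; pixel 1: they disagree, about 0.68.
print(pixelwise_entropy([[0.5, 0.0], [0.5, 3.0]]))
```

Pixels where the particles disagree receive high entropy, which is what makes the corresponding lines attractive to observe.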
- reweight_entropies_around_line(entropy_per_line, line_index)
Reweight the entropy around a selected line.
This approximates the decrease in entropy that would occur from observing that line. It works by multiplying the entropy values around the selected line by an upside-down Gaussian function centered at the selected line, setting the entropy of the selected line to 0, and decreasing the entropies of neighbouring lines.
Note
This function is not compatible with the torch backend. See Issue #268.
- Parameters:
entropy_per_line (Tensor) – Entropy per line of shape (n_possible_actions,)
line_index (int) – Index of the line with maximum entropy
- Returns:
The reweighted entropy per line, of shape (n_possible_actions,)
- Return type:
Tuple
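A sketch of the reweighting, assuming the upside-down Gaussian weight for a line at offset d from the selected line is 1 - exp(-(d - mean)^2 / (2 std_dev^2)), applied only within the num_lines_to_update window and left at 1 elsewhere. With mean = 0 this gives the selected line a weight of exactly 0. Unlike the documented method, this standalone function takes mean, std_dev and num_lines_to_update as arguments instead of reading them from the instance; it is hypothetical.

```python
import math


def reweight_entropies_around_line(entropy_per_line, line_index, mean=0.0, std_dev=1.0, num_lines_to_update=5):
    """Hypothetical sketch: damp entropies near line_index with an upside-down Gaussian."""
    assert num_lines_to_update % 2 == 1, "num_lines_to_update must be odd"
    half = num_lines_to_update // 2
    out = list(entropy_per_line)
    for offset in range(-half, half + 1):
        idx = line_index + offset
        if 0 <= idx < len(out):
            # 1 - Gaussian: 0 at the selected line (when mean == 0), approaching 1 further away.
            weight = 1.0 - math.exp(-((offset - mean) ** 2) / (2.0 * std_dev ** 2))
            out[idx] = out[idx] * weight
    return out


print(reweight_entropies_around_line([1.0] * 7, line_index=3))
```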
- sample(particles)
Sample the action using the greedy entropy method.
- Parameters:
particles (Tensor) – Particles of shape (batch_size, n_particles, height, width)
- Returns:
Newly selected lines as k-hot vectors, shaped (batch_size, n_possible_actions)
Masks of shape (batch_size, img_height, img_width)
- Return type:
Tuple[Tensor, Tensor]
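Greedy selection alternates between picking the current maximum-entropy line and reweighting around it, n_actions times. The hypothetical sketch below starts from an entropy-per-line vector for one batch element (in GreedyEntropy this would presumably be derived from the pixelwise entropy of the particles) and assumes mean = 0 for the reweighting Gaussian.

```python
import math


def greedy_select_lines(entropy_per_line, n_actions, std_dev=1.0, num_lines_to_update=5):
    """Hypothetical sketch of greedy entropy line selection for one batch element."""
    entropy = list(entropy_per_line)
    half = num_lines_to_update // 2
    selected = []
    for _ in range(n_actions):
        best = max(range(len(entropy)), key=entropy.__getitem__)  # max-entropy line
        selected.append(best)
        # Approximate the information gained by observing `best`: damp it and its neighbours.
        for offset in range(-half, half + 1):
            idx = best + offset
            if 0 <= idx < len(entropy):
                entropy[idx] *= 1.0 - math.exp(-(offset ** 2) / (2.0 * std_dev ** 2))
    k_hot = [1 if i in selected else 0 for i in range(len(entropy))]
    return k_hot, selected


k_hot, order = greedy_select_lines([0.1, 0.9, 0.8, 0.2, 0.7, 0.3], n_actions=2)
print(order)  # [1, 4]
```

In this example line 2 has the second-highest raw entropy, but after line 1 is chosen its entropy is damped, so line 4 is picked instead. This is how the reweighting spreads the selected lines across the image.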
- class zea.agent.selection.LinesActionModel(n_actions, n_possible_actions, img_width, img_height)
Bases: MaskActionModel
Base class for action selection methods that select lines.
Initialize the LinesActionModel.
- Parameters:
n_actions (int) – The number of actions the agent can take.
n_possible_actions (int) – The number of possible actions.
img_width (int) – The width of the input image.
img_height (int) – The height of the input image.
- stack_n_cols
The number of columns in the image that correspond to a single action.
- Type:
int
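The relation between a k-hot line vector and its mask can be sketched as follows. The sketch assumes that each action maps to stack_n_cols = img_width // n_possible_actions contiguous columns, ordered left to right, and that a selected column is unmasked over the full image height. The function lines_to_mask is hypothetical.

```python
def lines_to_mask(k_hot, img_height, img_width):
    """Hypothetical sketch: expand a k-hot line vector into an (img_height, img_width) mask."""
    n_possible_actions = len(k_hot)
    assert img_width % n_possible_actions == 0, "img_width must be divisible by n_possible_actions"
    stack_n_cols = img_width // n_possible_actions
    # Repeat every entry stack_n_cols times to get a per-column indicator.
    column_mask = [bit for bit in k_hot for _ in range(stack_n_cols)]
    return [list(column_mask) for _ in range(img_height)]


mask = lines_to_mask([0, 1, 0, 1], img_height=2, img_width=8)
print(mask)
# [[0, 0, 1, 1, 0, 0, 1, 1], [0, 0, 1, 1, 0, 0, 1, 1]]
```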
- class zea.agent.selection.MaskActionModel
Bases: object
Base class for any action selection method that does masking.
- class zea.agent.selection.TaskBasedLines(n_actions, n_possible_actions, img_width, img_height, downstream_task_function, mean=0, std_dev=1, num_lines_to_update=5, **kwargs)
Bases: GreedyEntropy
Task-based line selection for maximizing information gain.
This action selection strategy chooses lines to maximize information gain with respect to a downstream task outcome. It uses gradient-based saliency to identify which image regions contribute most to task uncertainty, then selects lines accordingly.
Initialize the TaskBasedLines action selection model.
- Parameters:
n_actions (int) – The number of actions the agent can take.
n_possible_actions (int) – The number of possible actions (line positions).
img_width (int) – The width of the input image.
img_height (int) – The height of the input image.
downstream_task_function (Callable) – A differentiable function that takes a batch of inputs and produces scalar outputs. This represents the downstream task for which information gain should be maximized.
mean (float) – The mean of the RBF used for reweighting. Defaults to 0.
std_dev (float) – The standard deviation of the RBF used for reweighting. Defaults to 1.
num_lines_to_update (int) – The number of lines around the selected line to update during reweighting. Must be odd. Defaults to 5.
**kwargs – Additional keyword arguments passed to the parent class.
- compute_output_and_saliency_propagation(particles)
Compute saliency-weighted posterior variance for task-based selection.
This method computes how much each pixel contributes to the variance of the downstream task output. It uses automatic differentiation to compute gradients of the task function with respect to each particle, then weights the posterior variance by the squared mean gradient.
- Parameters:
particles (Tensor) – Particles of shape (batch_size, n_particles, height, width) representing the posterior distribution over images.
- Returns:
Pixelwise contribution to downstream task variance, of shape (batch_size, height, width). Higher values indicate pixels that contribute more to task uncertainty.
- Return type:
Tensor
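Weighting the posterior variance by the squared mean gradient matches a first-order (delta-method) propagation of pixel variance to the task output that ignores covariances between pixels: Var[f(x)] is approximately the sum over pixels p of (df/dx_p)^2 Var[x_p]. The sketch below computes the per-pixel terms for one batch element. It replaces automatic differentiation with a user-supplied gradient function and uses the population variance over particles; both choices, and all names, are assumptions.

```python
def saliency_weighted_variance(particles, task_grad):
    """Hypothetical sketch: pixelwise contribution to downstream-task variance.

    particles: n_particles flat pixel lists (one batch element).
    task_grad: returns d task / d pixel for one particle (stands in for autodiff).
    contribution[p] = Var_particles[x_p] * (mean_particles[d task / d x_p]) ** 2
    """
    n = len(particles)
    grads = [task_grad(x) for x in particles]
    contribution = []
    for p in range(len(particles[0])):
        mean_x = sum(x[p] for x in particles) / n
        var_x = sum((x[p] - mean_x) ** 2 for x in particles) / n  # population variance
        mean_g = sum(g[p] for g in grads) / n
        contribution.append(var_x * mean_g ** 2)
    return contribution


def linear_task_grad(x):
    """Gradient of the hypothetical task f(x) = 3 * x[0]: only pixel 0 matters."""
    return [3.0, 0.0]


print(saliency_weighted_variance([[1.0, 5.0], [3.0, -5.0]], linear_task_grad))  # [9.0, 0.0]
```

Pixel 1 has much larger posterior variance than pixel 0, yet it contributes nothing because the task does not depend on it. This is the difference between task-based selection and plain entropy-based selection.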
- sample(particles)
Sample actions using task-based information gain maximization.
This method computes which lines would provide the most information about the downstream task by:
1. Computing the pixelwise contribution to task variance using gradients.
2. Aggregating contributions into line-wise scores.
3. Greedily selecting the lines with the highest contribution scores.
4. Reweighting scores around selected lines (inherited from GreedyEntropy).
- Parameters:
particles (Tensor) – Particles representing the posterior distribution, of shape (batch_size, n_particles, height, width).
- Returns:
selected_lines_k_hot: Selected lines as k-hot vectors, shaped (batch_size, n_possible_actions)
masks: Binary masks of shape (batch_size, img_height, img_width)
pixelwise_contribution_to_var_dst: Pixelwise contribution to downstream task variance, of shape (batch_size, height, width)
- Return type:
Tuple[Tensor, Tensor, Tensor]
Note
Unlike the parent GreedyEntropy class, this method returns an additional tensor containing the pixelwise contribution scores for analysis.
- sum_neighbouring_columns_into_n_possible_actions(full_linewise_salience)
Aggregate column-wise saliency into line-wise saliency scores.
This method groups neighboring columns together to create saliency scores for each possible line action. Since each line action may correspond to multiple image columns, this aggregation is necessary to match the action space.
- Parameters:
full_linewise_salience (Tensor) – Saliency values for each column, of shape (batch_size, full_image_width).
- Returns:
Aggregated saliency scores for each possible action, of shape (batch_size, n_possible_actions).
- Return type:
Tensor
- Raises:
AssertionError – If the image width is not evenly divisible by n_possible_actions.
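The aggregation can be sketched as a sum over contiguous groups of stack_n_cols columns. The column-wise input would presumably come from summing the pixelwise contribution over the image height; that step, the function name and the list-based types are assumptions.

```python
def sum_neighbouring_columns(full_linewise_salience, n_possible_actions):
    """Hypothetical sketch: sum groups of adjacent column scores into per-action scores.

    full_linewise_salience: list of rows, shape (batch_size, full_image_width).
    """
    out = []
    for row in full_linewise_salience:
        width = len(row)
        assert width % n_possible_actions == 0, "image width must be divisible by n_possible_actions"
        stack_n_cols = width // n_possible_actions
        out.append([sum(row[a * stack_n_cols:(a + 1) * stack_n_cols]) for a in range(n_possible_actions)])
    return out


print(sum_neighbouring_columns([[1, 2, 3, 4, 5, 6]], n_possible_actions=3))  # [[3, 7, 11]]
```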
- class zea.agent.selection.UniformRandomLines(n_actions, n_possible_actions, img_width, img_height)
Bases: LinesActionModel
Uniform random lines action selection.
Creates masks with uniformly randomly sampled lines.
Initialize the LinesActionModel.
- Parameters:
n_actions (int) – The number of actions the agent can take.
n_possible_actions (int) – The number of possible actions.
img_width (int) – The width of the input image.
img_height (int) – The height of the input image.
- stack_n_cols
The number of columns in the image that correspond to a single action.
- Type:
int
- sample(batch_size=1, seed=None)
Sample the action using the uniform random method.
Generates a batch of masks with uniformly randomly sampled lines.
- Parameters:
batch_size (int, optional) – The number of masks to generate. Defaults to 1.
seed (int | SeedGenerator | jax.random.key, optional) – Seed for random number generation. Defaults to None.
- Returns:
Newly selected lines as k-hot vectors, shaped (batch_size, n_possible_actions)
Masks of shape (batch_size, img_height, img_width)
- Return type:
Tuple[Tensor, Tensor]
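The following is a minimal sketch of uniform random line selection. It assumes each batch element receives n_actions distinct lines drawn without replacement, and it uses Python's random.Random in place of the int, SeedGenerator or JAX key accepted by the documented seed argument. The function is hypothetical.

```python
import random


def uniform_random_lines(n_actions, n_possible_actions, batch_size=1, seed=None):
    """Hypothetical sketch: k-hot rows with n_actions distinct lines drawn uniformly."""
    rng = random.Random(seed)
    batch = []
    for _ in range(batch_size):
        chosen = set(rng.sample(range(n_possible_actions), n_actions))  # without replacement
        batch.append([1 if i in chosen else 0 for i in range(n_possible_actions)])
    return batch


lines = uniform_random_lines(n_actions=3, n_possible_actions=10, batch_size=2, seed=0)
print([sum(row) for row in lines])  # [3, 3]
```

The resulting k-hot rows can be expanded into masks with the same column-repetition rule used for the other line-based strategies.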