# Source code for zea.tracking.segmentation

"""Segmentation-based tracker using contour matching.

.. seealso::
    A tutorial notebook where this model is used:
    :doc:`../notebooks/models/speckle_tracking_example`.

"""

from keras import ops

from zea.func.tensor import find_contour

from .base import BaseTracker


class SegmentationTracker(BaseTracker):
    """Segmentation-based tracker.

    This tracker segments each frame independently and finds the closest
    points on the segmented contour to the previous frame's points.

    Args:
        model: Segmentation model with a `call` method.
        preprocess_fn: Optional preprocessing function to apply to frames
            before segmentation. When omitted, frames are passed to the
            model unchanged.
        postprocess_fn: Postprocessing function to apply to the segmentation
            output. It is called as ``postprocess_fn(outputs, orig_shape)``
            and should return a binary mask of the target structure.
            Required, even though the signature defaults it to ``None``.

    Raises:
        ValueError: If ``postprocess_fn`` is not provided.
    """

    def __init__(
        self,
        model,
        preprocess_fn: callable = None,
        postprocess_fn: callable = None,
    ):
        """Initialize segmentation-based tracker."""
        super().__init__(ndim=2)
        self.model = model
        self.preprocess_fn = preprocess_fn
        self.postprocess_fn = postprocess_fn

        if self.preprocess_fn is None:
            # Identity fallback so `track` can call preprocess_fn unconditionally.
            self.preprocess_fn = lambda frame: frame

        if self.postprocess_fn is None:
            raise ValueError("A postprocess_fn must be provided to extract binary masks.")

    def track(
        self,
        prev_frame,
        next_frame,
        points,
    ):
        """Track points by segmenting next_frame and finding closest contour points.

        Args:
            prev_frame: Previous frame (not used, kept for interface compatibility).
            next_frame: Next frame to segment, shape (H, W).
            points: Points from previous frame, shape (N, 2) in (row, col) format.

        Returns:
            new_points: Closest points on next frame's contour, shape (N, 2).
                If the segmentation yields no contour points, ``points`` is
                returned unchanged.
        """
        # Keep the pre-processing shape: postprocess_fn receives it alongside
        # the raw model output.
        orig_shape = ops.shape(next_frame)

        frame_input = self.preprocess_fn(next_frame)
        outputs = self.model.call(frame_input)
        mask = self.postprocess_fn(outputs, orig_shape)

        contour_points = find_contour(mask)

        has_contour = ops.shape(contour_points)[0] > 0
        if has_contour:
            return self._find_closest_points(points, contour_points)

        # Nothing to snap to: fall back to the previous positions.
        return points

    def _find_closest_points(self, query_points, target_points):
        """Find closest target points to each query point.

        Args:
            query_points: Points to match, shape (N, 2).
            target_points: Points to match to, shape (M, 2). Must contain at
                least one point (``track`` guards this).

        Returns:
            Closest target points, shape (N, 2).
        """
        # Broadcast (N, 1, 2) against (1, M, 2) to get all pairwise offsets.
        query_expanded = ops.expand_dims(query_points, axis=1)  # (N, 1, 2)
        target_expanded = ops.expand_dims(target_points, axis=0)  # (1, M, 2)

        # Use ops.subtract rather than the `-` operator: the two inputs come
        # from different sources and may differ in dtype, and ops.subtract
        # applies Keras' own type promotion instead of relying on the
        # backend tensor's operator overload.
        diff = ops.subtract(query_expanded, target_expanded)  # (N, M, 2)

        # Squared distances are enough for argmin; no sqrt needed.
        sq_distances = ops.sum(ops.square(diff), axis=2)  # (N, M)

        closest_indices = ops.argmin(sq_distances, axis=1)  # (N,)
        return ops.take(target_points, closest_indices, axis=0)

    def __repr__(self):
        """String representation."""
        return f"SegmentationTracker(model={self.model.__class__.__name__})"