Source code for zea.tracking.base
"""Base tracker class for point tracking algorithms."""
from abc import ABC, abstractmethod
from typing import List
from keras import ops
[docs]
class BaseTracker(ABC):
"""Abstract base class for point tracking algorithms.
This class defines the interface for tracking algorithms in the zea package.
Implementations should handle both 2D and 3D tracking where applicable.
Args:
ndim: Number of dimensions (2 for 2D, 3 for 3D).
**kwargs: Tracker-specific parameters.
"""
def __init__(self, ndim: int = 2, **kwargs):
"""Initialize the tracker with parameters."""
self.ndim = ndim
if self.ndim not in [2, 3]:
raise ValueError(f"Only 2D and 3D tracking supported, got {ndim}D")
[docs]
@abstractmethod
def track(
self,
prev_frame,
next_frame,
points,
):
"""
Track points from prev_frame to next_frame.
Args:
prev_frame: Previous frame/volume of shape (H, W) or (D, H, W).
next_frame: Next frame/volume of shape (H, W) or (D, H, W).
points: Points to track, shape (N, ndim) in (y, x) or (z, y, x) format.
Returns:
new_points: Tracked point locations, shape (N, ndim).
"""
pass
[docs]
def track_sequence(
self,
frames: List,
initial_points,
) -> List:
"""
Track points through a sequence of frames.
Args:
frames: List of frames/volumes to track through.
initial_points: Starting points in first frame, shape (N, ndim).
Returns:
List of N arrays, where each array has shape (T, ndim) containing
the trajectory of one point through all T frames.
"""
n_frames = len(frames)
n_points = int(ops.shape(initial_points)[0])
frames_t = [ops.convert_to_tensor(f, dtype="float32") for f in frames]
current_points = ops.convert_to_tensor(initial_points, dtype="float32")
trajectories = [ops.zeros((n_frames, self.ndim), dtype="float32") for _ in range(n_points)]
# Set initial positions
for i in range(n_points):
trajectories[i] = ops.scatter_update(
trajectories[i], [[0]], ops.expand_dims(current_points[i], 0)
)
# Track frame by frame
for t in range(n_frames - 1):
new_points = self.track(frames_t[t], frames_t[t + 1], current_points)
for i in range(n_points):
trajectories[i] = ops.scatter_update(
trajectories[i], [[t + 1]], ops.expand_dims(new_points[i], 0)
)
current_points = new_points
return trajectories
def __repr__(self):
"""String representation of the tracker."""
return f"{self.__class__.__name__}(ndim={self.ndim})"