"""Lucas-Kanade optical flow tracker.

.. seealso::
    A tutorial notebook where this model is used:
    :doc:`../notebooks/models/speckle_tracking_example`.

"""
from typing import Tuple
from keras import ops
from zea.func.tensor import gaussian_filter, translate
from .base import BaseTracker
class LucasKanadeTracker(BaseTracker):
    """Pyramidal Lucas-Kanade optical flow tracker for 2D frames and 3D volumes.

    For every point a template window is cut from the previous frame, and the
    displacement that best aligns it with the next frame is found by iterating
    the Lucas-Kanade normal equations. The estimate is refined from the
    coarsest to the finest level of a Gaussian pyramid.

    Args:
        win_size: Window size (height, width) for 2D or (depth, height, width)
            for 3D. The number of entries selects the dimensionality. The
            window actually sampled has ``2 * (w // 2) + 1`` samples along each
            axis, so it is always odd-sized and centred on the point.
        max_level: Number of pyramid levels on top of the full-resolution
            image (0 means no pyramid).
        max_iterations: Maximum iterations per pyramid level.
        epsilon: Convergence threshold on the norm of the flow update.
        **kwargs: Additional parameters forwarded to ``BaseTracker``.

    Example:
        .. doctest::

            >>> from zea.tracking import LucasKanadeTracker
            >>> import numpy as np

            >>> tracker = LucasKanadeTracker(win_size=(32, 32), max_level=3)
            >>> frame1 = np.random.rand(100, 100).astype("float32")
            >>> frame2 = np.random.rand(100, 100).astype("float32")
            >>> points = np.array([[50.5, 55.2], [60.1, 65.8]], dtype="float32")
            >>> new_points = tracker.track(frame1, frame2, points)
            >>> new_points.shape
            (2, 2)
    """

    def __init__(
        self,
        win_size: Tuple[int, ...] = (32, 32),
        max_level: int = 3,
        max_iterations: int = 30,
        epsilon: float = 0.01,
        **kwargs,
    ):
        """Initialize custom Lucas-Kanade tracker."""
        # Dimensionality is implied by the window: 2 entries -> 2D, 3 -> 3D.
        self.ndim = len(win_size)
        super().__init__(ndim=self.ndim, **kwargs)
        self.win_size = win_size
        self.max_level = max_level
        self.max_iterations = max_iterations
        self.epsilon = epsilon
        # Half extents; windows span [-half, +half] around the tracked point.
        self.half_win = tuple(w // 2 for w in win_size)

    # NOTE(review): the return annotation says ``Tuple`` but a single tensor is
    # returned; kept unchanged here -- confirm against ``BaseTracker.track``.
    def track(
        self,
        prev_frame,
        next_frame,
        points,
    ) -> Tuple:
        """Track points using custom pyramidal Lucas-Kanade.

        Args:
            prev_frame: Previous frame/volume (tensor), shape (H, W) for 2D or
                (D, H, W) for 3D.
            next_frame: Next frame/volume (tensor), same layout as
                ``prev_frame``.
            points: Points to track (tensor), shape (N, ndim) in (y, x) or
                (z, y, x) format.

        Returns:
            new_points: Tracked points as tensor, shape (N, ndim). Points whose
            window does not fit inside the image at a given level keep the
            flow estimated so far.

        Raises:
            NotImplementedError: If the tracker was not configured for 2D or 3D.
        """
        if self.ndim not in [2, 3]:
            raise NotImplementedError(f"Only 2D and 3D tracking supported, got {self.ndim}D")

        n_points = int(points.shape[0])
        flows = ops.zeros((n_points, self.ndim), dtype="float32")
        if n_points == 0:
            # Nothing to track; ops.stack() below would reject an empty list.
            return ops.reshape(points, (0, self.ndim)) + flows

        # Rescale both frames to the [0, 1] range before comparing intensities.
        prev_norm = translate(prev_frame, range_to=(0, 1))
        next_norm = translate(next_frame, range_to=(0, 1))

        # Pyramids are ordered coarsest first. With max_level <= 0 the builder
        # returns just the full-resolution image.
        prev_pyr = self._build_pyramid(prev_norm, self.max_level + 1)
        next_pyr = self._build_pyramid(next_norm, self.max_level + 1)
        n_levels = len(prev_pyr)

        # Express the points in the coordinate system of the coarsest level.
        scale = 2 ** (n_levels - 1)
        curr_points = points / scale

        for level, (prev_img, next_img) in enumerate(zip(prev_pyr, next_pyr)):
            flows = ops.stack(
                [
                    self._track_point(prev_img, next_img, curr_points[i], flows[i])
                    for i in range(n_points)
                ]
            )
            # Moving one level finer doubles both coordinates and displacements.
            if level < n_levels - 1:
                flows = flows * 2.0
                curr_points = curr_points * 2.0

        # After the finest level the flows are in full-resolution pixels.
        new_points = points + flows
        return new_points

    def _build_pyramid(self, image, n_levels: int) -> list:
        """Build a Gaussian pyramid, returned coarsest level first.

        Each level is the previous one blurred and subsampled at every second
        sample (indices 0, 2, 4, ...), so level coordinates scale by exactly 2,
        matching the rescaling applied in :meth:`track`.

        Args:
            image: Full-resolution frame/volume.
            n_levels: Maximum number of levels including the full-resolution
                one. Construction stops early once any spatial dimension of the
                current level is smaller than 4.

        Returns:
            List of images ordered from coarsest to finest.
        """
        pyramid = [image]
        for _ in range(1, n_levels):
            curr = pyramid[-1]
            shape = ops.shape(curr)
            dims = [shape[axis] for axis in range(self.ndim)]
            if min(dims) < 4:
                break

            blurred = gaussian_filter(curr, sigma=0.849, mode="reflect")

            # Sample at even indices: coarse index k <-> fine index 2 * k.
            axes = [ops.arange(size // 2, dtype="float32") * 2.0 for size in dims]
            grids = ops.meshgrid(*axes, indexing="ij")
            coords = ops.stack(list(grids), axis=0)
            pyramid.append(ops.image.map_coordinates(blurred, coords, order=1))

        return pyramid[::-1]

    def _track_point(
        self,
        prev_img,
        next_img,
        point,
        flow_guess,
    ):
        """Track a single point using iterative Lucas-Kanade.

        Args:
            prev_img: Previous image at the current pyramid level.
            next_img: Next image at the current pyramid level.
            point: Point location at this level, (y, x) or (z, y, x).
            flow_guess: Initial displacement in the same axis order.

        Returns:
            Refined displacement. ``flow_guess`` is returned unchanged when the
            template window does not fit inside ``prev_img``.
        """
        template, valid_template = self._extract_window(prev_img, point)
        if not valid_template:
            return flow_guess

        # One flattened gradient per axis, stacked in (z,) y, x order so the
        # solved update is already in the same order as the points.
        gradients = self._sobel_gradients(template)
        n_axes = len(gradients)
        grad_matrix = ops.stack([ops.reshape(g, [-1]) for g in gradients], axis=0)

        # The structure tensor depends only on the template, so it is built,
        # regularised and inverted once instead of on every iteration.
        structure = ops.matmul(grad_matrix, ops.transpose(grad_matrix))
        structure = structure + ops.eye(n_axes, dtype=structure.dtype) * 1e-5
        structure_inv = ops.linalg.inv(structure)

        flow = flow_guess
        for _ in range(self.max_iterations):
            warped, valid_warped = self._extract_window(next_img, point + flow)
            if not valid_warped:
                # The displaced window left the image; keep the last estimate.
                break

            # Right-hand side: gradients projected on the intensity residual.
            residual = ops.reshape(template - warped, [-1, 1])
            rhs = ops.matmul(grad_matrix, residual)
            delta = ops.reshape(ops.matmul(structure_inv, rhs), (n_axes,))

            flow = flow + delta

            if ops.sqrt(ops.sum(delta * delta)) < self.epsilon:
                break

        return flow

    def _extract_window(self, image, point):
        """Extract window around point with subpixel interpolation.

        Returns:
            Tuple ``(window, valid)``; ``valid`` is False when the window would
            not fit inside ``image``.

        Raises:
            ValueError: If the tracker is neither 2D nor 3D.
        """
        if self.ndim == 2:
            return self._extract_window_2d(image, point)
        elif self.ndim == 3:
            return self._extract_window_3d(image, point)
        else:
            raise ValueError(f"Unsupported ndim: {self.ndim}")

    def _extract_window_2d(self, image, point):
        """Extract 2D window with bilinear interpolation using map_coordinates."""
        return self._extract_window_nd(image, point)

    def _extract_window_3d(self, image, point):
        """Extract 3D window with trilinear interpolation using map_coordinates."""
        return self._extract_window_nd(image, point)

    def _extract_window_nd(self, image, point):
        """Sample a ``2 * half + 1`` window per axis, centred on ``point``.

        The window is rejected (zeros and ``False`` are returned) when the
        point is closer than ``half + 1`` samples to any image border.
        """
        half_win = self.half_win
        shape = ops.shape(image)

        out_of_bounds = []
        for axis, half in enumerate(half_win):
            coord = point[axis]
            out_of_bounds.append(coord < half + 1)
            out_of_bounds.append(coord >= ops.cast(shape[axis], coord.dtype) - half - 1)
        if ops.any(ops.stack(out_of_bounds)):
            empty = ops.zeros(tuple(2 * half + 1 for half in half_win), dtype="float32")
            return empty, False

        # Per-axis sample positions: point - half, ..., point, ..., point + half.
        axes = [
            ops.arange(2 * half + 1, dtype="float32") + point[axis] - half
            for axis, half in enumerate(half_win)
        ]
        grids = ops.meshgrid(*axes, indexing="ij")
        coords = ops.stack(list(grids), axis=0)

        # order=1 gives (bi/tri)linear interpolation at the subpixel positions.
        window = ops.image.map_coordinates(image, coords, order=1)
        return window, True

    def _sobel_gradients(self, image):
        """Compute Sobel gradients for 2D or 3D.

        Returns:
            Tuple of gradient images in axis order, (Iy, Ix) or (Iz, Iy, Ix).

        Raises:
            ValueError: If the tracker is neither 2D nor 3D.
        """
        if self.ndim == 2:
            return self._sobel_gradients_2d(image)
        elif self.ndim == 3:
            return self._sobel_gradients_3d(image)
        else:
            raise ValueError(f"Unsupported ndim: {self.ndim}")

    def _sobel_gradients_2d(self, image):
        """Compute 2D Sobel gradients using keras.ops; returns ``(Iy, Ix)``."""
        # Standard Sobel kernels, divided by 8 (the sum of the absolute weights).
        sobel_y = ops.convert_to_tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype="float32") / 8.0
        sobel_x = ops.convert_to_tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype="float32") / 8.0
        return self._convolve_same(image, (sobel_y, sobel_x))

    def _sobel_gradients_3d(self, image):
        """Compute 3D Sobel gradients using keras.ops; returns ``(Iz, Iy, Ix)``."""
        # 3D Sobel kernels (separable: smooth in 2 dims, gradient in 1 dim),
        # divided by 32 (the sum of the absolute weights).
        # Gradient in z-direction
        sobel_z = (
            ops.convert_to_tensor(
                [
                    [[-1, -2, -1], [-2, -4, -2], [-1, -2, -1]],
                    [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
                    [[1, 2, 1], [2, 4, 2], [1, 2, 1]],
                ],
                dtype="float32",
            )
            / 32.0
        )
        # Gradient in y-direction
        sobel_y = (
            ops.convert_to_tensor(
                [
                    [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                    [[-2, -4, -2], [0, 0, 0], [2, 4, 2]],
                    [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                ],
                dtype="float32",
            )
            / 32.0
        )
        # Gradient in x-direction
        sobel_x = (
            ops.convert_to_tensor(
                [
                    [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                    [[-2, 0, 2], [-4, 0, 4], [-2, 0, 2]],
                    [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                ],
                dtype="float32",
            )
            / 32.0
        )
        return self._convolve_same(image, (sobel_z, sobel_y, sobel_x))

    def _convolve_same(self, image, kernels):
        """Apply each 3-wide kernel to ``image`` with reflect padding.

        The image is padded by one sample per side and convolved with 'valid'
        padding, so every output has the same shape as ``image``.

        Returns:
            Tuple with one filtered image per kernel, in the given order.
        """
        shape = list(ops.shape(image))
        n_axes = len(shape)
        padded = ops.pad(image, [[1, 1]] * n_axes, mode="reflect")

        # ops.conv expects (batch, *spatial, channels) inputs and
        # (*spatial, in_channels, out_channels) kernels.
        batched = ops.reshape(padded, [1] + [size + 2 for size in shape] + [1])
        outputs = []
        for kernel in kernels:
            kernel_nd = ops.reshape(kernel, [3] * n_axes + [1, 1])
            response = ops.conv(batched, kernel_nd, padding="valid")
            outputs.append(ops.reshape(response, shape))
        return tuple(outputs)

    def __repr__(self):
        """String representation."""
        return (
            f"LucasKanadeTracker(win_size={self.win_size}, max_level={self.max_level}, "
            f"max_iterations={self.max_iterations}, epsilon={self.epsilon})"
        )