"""EchoNetLVH model for segmentation of PLAX view cardiac ultrasound.
To try this model, simply load one of the available presets:
.. doctest::
>>> from zea.models.echonetlvh import EchoNetLVH
>>> model = EchoNetLVH.from_preset("echonetlvh")
.. important::
This is a ``zea`` implementation of the model.
For the original paper and code, see `here <https://echonet.github.io/lvh/>`_.
Duffy, Grant, et al.
"High-throughput precision phenotyping of left ventricular hypertrophy with cardiovascular deep learning."
*JAMA cardiology 7.4 (2022): 386-395*
.. seealso::
A tutorial notebook where this model is used:
:doc:`../notebooks/agent/task_based_perception_action_loop`.
""" # noqa: E501
import numpy as np
from keras import ops
from zea.func.tensor import translate
from zea.internal.registry import model_registry
from zea.models.base import BaseModel
from zea.models.deeplabv3 import DeeplabV3Plus
from zea.models.preset_utils import register_presets
from zea.models.presets import echonet_lvh_presets
[docs]
@model_registry(name="echonetlvh")
class EchoNetLVH(BaseModel):
"""
EchoNet Left Ventricular Hypertrophy (LVH) model for echocardiogram analysis.
This model performs semantic segmentation on echocardiogram images to identify
key anatomical landmarks for measuring left ventricular wall thickness:
- **LVPWd_1**: Left Ventricular Posterior Wall point 1
- **LVPWd_2**: Left Ventricular Posterior Wall point 2
- **IVSd_1**: Interventricular Septum point 1
- **IVSd_2**: Interventricular Septum point 2
The model outputs 4-channel logits corresponding to heatmaps for each landmark.
For more information, see the original project page:
https://echonet.github.io/lvh/
"""
def __init__(self, **kwargs):
"""
Initialize the EchoNetLVH model.
Args:
**kwargs: Additional keyword arguments passed to BaseModel
"""
super().__init__(**kwargs)
# Pre-computed coordinate grid for efficient processing
self.coordinate_grid = ops.stack(
ops.cast(ops.convert_to_tensor(np.indices((224, 224))), "float32"), axis=-1
)
# Initialize the underlying segmentation network
self.network = DeeplabV3Plus(image_shape=(224, 224, 3), num_classes=4)
[docs]
def call(self, inputs):
"""
Forward pass of the model.
Args:
inputs (Tensor): Input images of shape [B, H, W, C]. They should
be scan converted, with pixel values in range [0, 255].
Returns:
Tensor: Logits of shape [B, H, W, 4] with 4 channels for each landmark
"""
assert len(ops.shape(inputs)) == 4
# Store original dimensions for output resizing
original_size = ops.shape(inputs)[1:3]
# Resize to network input size
inputs_resized = ops.image.resize(inputs, size=(224, 224))
# Get network predictions
logits = self.network(inputs_resized)
# Resize logits back to original input dimensions
logits_output = ops.image.resize(logits, original_size)
return logits_output
[docs]
def expected_coordinate(self, mask, coordinate_grid=None):
"""
Compute the expected coordinate (center-of-mass) of a heatmap.
This implements a differentiable version of taking the max of a heatmap
by computing the weighted average of coordinates.
Reference: https://arxiv.org/pdf/1711.08229
Args:
mask (Tensor): Heatmap of shape [B, H, W]
coordinate_grid (Tensor, optional): Grid of coordinates. If None,
uses self.coordinate_grid
Returns:
Tensor: Expected coordinates of shape [B, 2] in (x, y) format
"""
if coordinate_grid is None:
coordinate_grid = self.coordinate_grid
# Ensure mask values are non-negative and normalized
mask_clipped = ops.clip(mask, 0, None)
mask_normed = mask_clipped / ops.max(mask_clipped)
def safe_normalize(m):
mask_sum = ops.sum(m)
return ops.where(mask_sum > 0, m / mask_sum, m)
coordinate_probabilities = ops.map(safe_normalize, mask_normed)
# Add dimension for broadcasting with coordinate grid
coordinate_probabilities = ops.expand_dims(coordinate_probabilities, axis=-1)
# Compute weighted average of coordinates
expected_coordinate = ops.sum(
ops.expand_dims(coordinate_grid, axis=0) * coordinate_probabilities,
axis=(1, 2),
)
# Flip to convert from (y, x) to (x, y) format for euclidean distance calculation
return ops.flip(expected_coordinate, axis=-1)
[docs]
def overlay_labels_on_image(self, image, label, alpha=0.5):
"""
Overlay predicted heatmaps and connecting lines on the input image.
Args:
image (Tensor): Input image of shape [H, W] or [H, W, C]
label (Tensor): Predicted logits of shape [H, W, 4]
alpha (float): Blending factor for overlay (0=transparent, 1=opaque)
Returns:
ndarray: Image with overlaid heatmaps and measurements of shape [H, W, 3]
"""
try:
import cv2
except ImportError as exc:
raise ImportError(
"OpenCV is required for `EchoNetLVH.overlay_labels_on_image`. "
"Please install it with 'pip install opencv-python' or "
"'pip install opencv-python-headless'."
) from exc
# Color scheme for each landmark
overlay_colors = np.array(
[
[1, 1, 0], # Yellow (LVPWd_X1)
[1, 0, 1], # Magenta (LVPWd_X2)
[0, 1, 1], # Cyan (IVSd_X1)
[0, 1, 0], # Green (IVSd_X2)
],
)
# Convert to numpy and ensure RGB format
image = ops.convert_to_numpy(image)
label = ops.convert_to_numpy(label)
if image.ndim == 2:
image = np.stack([image] * 3, axis=-1)
elif image.shape[-1] == 1:
image = np.repeat(image, 3, axis=-1)
else:
image = image.copy()
# Normalize each channel to [0, 1] for proper visualization
label = np.clip(label, 0, None)
for ch in range(label.shape[-1]):
max_val = np.max(label[..., ch])
if max_val > 0:
label[..., ch] = label[..., ch] / max_val
# Initialize overlay and tracking variables
overlay = np.zeros_like(image, dtype=np.float32)
centers = []
# Process each landmark channel
for ch in range(4):
# Square the mask to enhance peak responses
mask = label[..., ch] ** 2
color = overlay_colors[ch]
# Find center of mass for this channel
center_coords = self.expected_coordinate(ops.expand_dims(mask, axis=0))
center_x = ops.convert_to_numpy(center_coords[0, 0])
center_y = ops.convert_to_numpy(center_coords[0, 1])
# Bounds check before conversion to int
if 0 <= center_x < image.shape[1] and 0 <= center_y < image.shape[0]:
center = (int(center_x), int(center_y))
else:
center = None
if center is not None:
# Blend heatmap with overlay
mask_alpha = mask * alpha
for c in range(3):
overlay[..., c] += mask_alpha * color[c]
centers.append(center)
# Draw connecting lines between consecutive landmarks
for i in range(3):
pt1, pt2 = centers[i], centers[i + 1]
if pt1 is not None and pt2 is not None:
color = tuple(int(x) for x in overlay_colors[i])
# Create line mask
line_mask = np.zeros(image.shape[:2], dtype=np.uint8)
cv2.line(line_mask, pt1, pt2, color=1, thickness=2)
# Apply line to overlay
for c in range(3):
overlay[..., c][line_mask.astype(bool)] = color[c] * alpha
# Blend overlay with original image
overlay = np.clip(overlay, 0, 1)
out = image.astype(np.float32)
blend_mask = np.any(overlay > 0.02, axis=-1)
out[blend_mask] = (1 - alpha) * out[blend_mask] + overlay[blend_mask]
return np.clip(out, 0, 1)
[docs]
def visualize_logits(self, images, logits):
"""
Create visualization of model predictions overlaid on input images.
Args:
images (Tensor): Input images of shape [B, H, W, C]
logits (Tensor): Model predictions of shape [B, H, W, 4]
Returns:
Tensor: Images with overlaid predictions of shape [B, H, W, 3]
"""
# Store original dimensions for final output
original_size = ops.shape(images)[1:3]
# Resize to standard processing size
images_resized = ops.image.resize(images, size=(224, 224), interpolation="nearest")
logits_resized = ops.image.resize(logits, size=(224, 224), interpolation="nearest")
# Normalize images to [0, 1] range
images_clipped = ops.clip(images_resized, 0, 255)
images = translate(images_clipped, range_from=(0, 255), range_to=(0, 1))
# Generate overlays for each image in the batch
images_with_overlay = []
for img, logit_heatmap in zip(images, logits_resized):
overlay = self.overlay_labels_on_image(img, logit_heatmap)
images_with_overlay.append(overlay)
# Stack results and resize back to original dimensions
images_with_overlay = np.stack(images_with_overlay, axis=0)
return ops.image.resize(images_with_overlay, original_size)
# Register model presets
register_presets(echonet_lvh_presets, EchoNetLVH)