"""Visualization functions for 2D and 3D ultrasound data."""
import importlib.resources
from typing import List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
from keras.ops.image import crop_images
from matplotlib.patches import PathPatch, Rectangle
from matplotlib.path import Path as pltPath
from mpl_toolkits.axes_grid1 import ImageGrid
from scipy.ndimage import zoom
from skimage import measure
from zea.display import frustum_convert_rtp2xyz
DEFAULT_STYLE = importlib.resources.files("zea") / "zea_darkmode.mplstyle"
[docs]
def set_mpl_style(style: str = None) -> None:
"""Set the matplotlib style.
Args:
style (str, optional): Path to the matplotlib style file.
Defaults to "zea_darkmode.mplstyle", which is the default
darkmode style used throughout the zea toolbox.
"""
if style is None:
style = DEFAULT_STYLE
plt.style.use(style)
[docs]
def plot_image_grid(
images: List[np.ndarray],
ncols: Optional[int] = None,
cmap: Optional[Union[str, List[str]]] = "gray",
vmin: Optional[Union[float, List[float]]] = None,
vmax: Optional[Union[float, List[float]]] = None,
interpolation: Optional[str] = "auto",
titles: Optional[List[str]] = None,
suptitle: Optional[str] = None,
aspect: Optional[Union[str, int, float, List[Union[str, int, float]]]] = None,
figsize: Optional[Tuple[float, float]] = None,
fig: Optional[plt.Figure] = None,
fig_contents: Optional[List] = None,
remove_axis: Optional[bool] = True,
background_color: Optional[str] = None,
text_color: Optional[str] = None,
axes_pad: float = 0.1,
**kwargs,
) -> Tuple[plt.Figure, List]:
"""Plot a batch of images in a grid.
Args:
images (List[np.ndarray]): batch of images.
ncols (int, optional): Number of columns. Defaults to None.
cmap (str or list, optional): Colormap. Defaults to 'gray'.
If list, cmap must be of same length as images and is set for each axis.
vmin (float or list, optional): Minimum plot value. Defaults to None.
if list vmin must be of same length as images and is set for each axis.
vmax (float or list , optional): Maximum plot value. Defaults to None.
if list vmax must be of same length as images and is set for each axis.
interpolation (str, optional): Interpolation method that mpl uses. Defaults to 'auto'.
titles (list, optional): List of titles for subplots. Defaults to None.
suptitle (str, optional): Title for the plot. Defaults to None.
aspect (optional): Aspect ratio for imshow.
figsize (tuple, optional): Figure size. Defaults to None.
fig (figure, optional): Matplotlib figure object. Defaults to None. Can
be used to plot on an existing figure.
fig_contents (list, optional): List of matplotlib image objects. Defaults to None.
remove_axis (bool, optional): Whether to remove axis. Defaults to True. If False, axes r
emain but spines are colored to background and ticks/labels are hidden,
allowing later label drawing to remain visible.
background_color (str, optional): Background color. Defaults to None. (Matplotlib default)
text_color (str, optional): Text color. Defaults to None. (Matplotlib default)
axes_pad (float, optional): Padding between axes. Defaults to 0.1.
**kwargs: arguments for plt.Figure.
Returns:
fig (figure): Matplotlib figure object
fig_contents (list): List of matplotlib image objects.
Example:
.. doctest::
>>> from zea.visualize import plot_image_grid
>>> import numpy as np
>>> images = [np.random.rand(128, 128) for _ in range(6)]
>>> fig, fig_contents = plot_image_grid(
... images,
... ncols=3,
... cmap="gray",
... vmin=0,
... vmax=1,
... )
"""
if ncols is None:
factors = [i for i in range(1, len(images) + 1) if len(images) % i == 0]
ncols = factors[len(factors) // 2] if len(factors) else len(images) // 4 + 1
nrows = int(len(images) / ncols) + int(len(images) % ncols)
images = [images[i] if len(images) > i else None for i in range(nrows * ncols)]
aspect_ratio = images[0].shape[1] / images[0].shape[0]
if figsize is None:
figsize = (ncols * 2, nrows * 2 / aspect_ratio)
# get default colors for matplotlib
if background_color is None:
background_color = plt.rcParams["axes.facecolor"]
if text_color is None:
text_color = plt.rcParams["text.color"]
# either supply both fig and fig_contents or neither
assert (fig is None) == (fig_contents is None), "Supply both fig and fig_contents or neither"
if fig is None:
fig = plt.figure(figsize=figsize, **kwargs)
axes = ImageGrid(fig, 111, nrows_ncols=(nrows, ncols), axes_pad=axes_pad)
if background_color is not None:
fig.patch.set_facecolor(background_color)
fig.set_layout_engine("tight", pad=0.1)
else:
axes = fig.axes[: len(images)]
if isinstance(cmap, str):
cmap = [cmap] * len(images)
else:
if cmap is None:
cmap = [None] * len(images)
assert len(cmap) == len(images), (
f"cmap must be a string or list of strings of length {len(images)}, but got {cmap}"
)
if isinstance(vmin, (int, float)):
vmin = [vmin] * len(images)
else:
if vmin is None:
vmin = [None] * len(images)
assert len(vmin) == len(images), (
f"vmin must be a float or list of floats of length {len(images)}, but got {vmin}"
)
if isinstance(vmax, (int, float)):
vmax = [vmax] * len(images)
else:
if vmax is None:
vmax = [None] * len(images)
assert len(vmax) == len(images), (
f"vmax must be a float or list of floats of length {len(images)}, but got {vmax}"
)
if isinstance(aspect, (int, float, str)):
aspect = [aspect] * len(images)
else:
if aspect is None:
aspect = [None] * len(images)
assert len(aspect) == len(images), (
"aspect must be a float, int, str, or list of these "
f"of length {len(images)}, but got {aspect}"
)
if fig_contents is None:
fig_contents = [None for _ in range(len(images))]
for i, ax in enumerate(axes):
image = images[i]
if fig_contents[i] is None:
im = ax.imshow(
image,
cmap=cmap[i],
vmin=vmin[i],
vmax=vmax[i],
aspect=aspect[i],
interpolation=interpolation,
)
fig_contents[i] = im
else:
fig_contents[i].set_data(image)
if remove_axis:
ax.axis("off")
else:
for spine in ax.spines.values():
# spine.set_visible(False)
spine.set_color(background_color)
ax.tick_params(
axis="both",
which="both",
bottom=False,
top=False,
left=False,
right=False,
)
ax.set_xticklabels([])
ax.set_yticklabels([])
if titles:
ax.set_title(titles[i], color=text_color)
if suptitle:
fig.suptitle(suptitle, color=text_color)
fig.set_layout_engine("none")
# use bbox_inches="tight" for proper tight layout when saving
return fig, fig_contents
[docs]
def plot_quadrants(ax, array, fixed_coord, cmap, slice_index, stride=1, centroid=None, **kwargs):
"""
For a given 3D array, plot a plane with fixed_coord using four individual quadrants.
Args:
ax (matplotlib.axes.Axes3DSubplot): The 3D axis to plot on.
array (numpy.ndarray): The 3D array to be plotted.
fixed_coord (str): The coordinate to be fixed ('x', 'y', or 'z').
cmap (str): The colormap to be used for plotting.
slice_index (int or None): The index of the slice to be plotted.
If None, the middle slice is used.
stride (int, optional): The stride step for plotting. Defaults to 1.
centroid (tuple, optional): centroid around which to break the quadrants.
If None, the middle of the image is used.
**kwargs: Additional keyword arguments for the plot_surface method.
Returns:
matplotlib.axes.Axes3DSubplot: The axis with the plotted quadrants.
"""
nx, ny, nz = array.shape
index = {
"x": (
slice_index if slice_index is not None else nx // 2,
slice(None),
slice(None),
),
"y": (
slice(None),
slice_index if slice_index is not None else ny // 2,
slice(None),
),
"z": (
slice(None),
slice(None),
slice_index if slice_index is not None else nz // 2,
),
}[fixed_coord]
plane_data = array[index]
if centroid is None:
centroid = [x // 2 for x in array.shape]
coords = {"x": (1, 2), "y": (0, 2), "z": (0, 1)}
n0, n1 = (centroid[i] for i in coords[fixed_coord])
quadrants = [
plane_data[:n0, :n1],
plane_data[:n0, n1:],
plane_data[n0:, :n1],
plane_data[n0:, n1:],
]
min_val = np.nanmin(array)
max_val = np.nanmax(array)
cmap = plt.get_cmap(cmap)
for i, quadrant in enumerate(quadrants):
facecolors = cmap((quadrant - min_val) / (max_val - min_val))
if fixed_coord == "x":
Y, Z = np.mgrid[: quadrant.shape[0] + 1, : quadrant.shape[1] + 1]
X = (slice_index if slice_index is not None else nx // 2) * np.ones_like(Y)
Y_offset = (i // 2) * n0
Z_offset = (i % 2) * n1
ax.plot_surface(
X,
Y + Y_offset,
Z + Z_offset,
rstride=stride,
cstride=stride,
facecolors=facecolors,
shade=False,
**kwargs,
)
elif fixed_coord == "y":
X, Z = np.mgrid[: quadrant.shape[0] + 1, : quadrant.shape[1] + 1]
Y = (slice_index if slice_index is not None else ny // 2) * np.ones_like(X)
X_offset = (i // 2) * n0
Z_offset = (i % 2) * n1
ax.plot_surface(
X + X_offset,
Y,
Z + Z_offset,
rstride=stride,
cstride=stride,
facecolors=facecolors,
shade=False,
**kwargs,
)
elif fixed_coord == "z":
X, Y = np.mgrid[: quadrant.shape[0] + 1, : quadrant.shape[1] + 1]
Z = (slice_index if slice_index is not None else nz // 2) * np.ones_like(X)
X_offset = (i // 2) * n0
Y_offset = (i % 2) * n1
ax.plot_surface(
X + X_offset,
Y + Y_offset,
Z,
rstride=stride,
cstride=stride,
facecolors=facecolors,
shade=False,
**kwargs,
)
return ax
[docs]
def plot_biplanes(
volume,
cmap="gray",
resolution=1.0,
stride=1,
slice_x=None,
slice_y=None,
slice_z=None,
show_axes=None,
fig=None,
ax=None,
**kwargs,
):
"""
Plot three intersecting planes from a 3D volume in 3D space.
Also known as ultrasound biplane visualization.
Args:
volume (ndarray): 3D numpy array representing the volume to be plotted.
cmap (str, optional): Colormap to be used for plotting. Defaults to "gray".
resolution (float, optional): Resolution factor for the volume. Defaults to 1.0.
stride (int, optional): Stride for plotting the quadrants. Defaults to 1.
slice_x (int, optional): Index for the slice in the x-plane. Defaults to None.
slice_y (int, optional): Index for the slice in the y-plane. Defaults to None.
slice_z (int, optional): Index for the slice in the z-plane. Defaults to None.
show_axes (dict, optional): Dictionary to specify axis labels and extents.
Defaults to None.
fig (matplotlib.figure.Figure, optional): Matplotlib figure object.
Defaults to None. Can be used to reuse the figure in a loop.
ax (matplotlib.axes.Axes3DSubplot, optional): Matplotlib 3D axes object.
Defaults to None. Can be used to reuse the axes in a loop.
**kwargs: Additional keyword arguments for the plot_surface method.
Returns:
tuple: A tuple containing the figure and axes objects (fig, ax).
Raises:
AssertionError: If none of slice_x, slice_y, or slice_z are provided.
"""
assert slice_x is not None or slice_y is not None or slice_z is not None, (
"At least one slice index must be set."
)
volume = zoom(volume, (resolution, resolution, resolution), order=1)
# Adjust slice indices if resolution < 1
if resolution < 1:
if slice_x is not None:
slice_x = int(slice_x * resolution)
if slice_y is not None:
slice_y = int(slice_y * resolution)
if slice_z is not None:
slice_z = int(slice_z * resolution)
# volume is grid_size_z, grid_size_x, n_y -> grid_size_x, n_y, grid_size_z
volume = np.transpose(volume, (1, 2, 0))
volume = np.flip(volume, axis=2) # Flip the z-axis
if fig is None:
fig = plt.figure()
if ax is None:
ax = fig.add_subplot(projection="3d")
ax.set_box_aspect(volume.shape)
# Remove background and axes faces
ax.grid(False)
ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False
if slice_x is not None:
plot_quadrants(ax, volume, "x", cmap=cmap, slice_index=slice_x, stride=stride, **kwargs)
if slice_y is not None:
plot_quadrants(ax, volume, "y", cmap=cmap, slice_index=slice_y, stride=stride, **kwargs)
if slice_z is not None:
plot_quadrants(ax, volume, "z", cmap=cmap, slice_index=slice_z, stride=stride, **kwargs)
# Optionally show axes
if show_axes:
ax.set_xlabel(show_axes.get("x", ""))
ax.set_ylabel(show_axes.get("y", ""))
ax.set_zlabel(show_axes.get("z", ""))
if "x_extent" in show_axes:
ax.set_xticks(np.linspace(0, volume.shape[0], len(show_axes["x_extent"])))
ax.set_xticklabels(show_axes["x_extent"])
if "y_extent" in show_axes:
ax.set_yticks(np.linspace(0, volume.shape[1], len(show_axes["y_extent"])))
ax.set_yticklabels(show_axes["y_extent"])
if "z_extent" in show_axes:
ax.set_zticks(np.linspace(0, volume.shape[2], len(show_axes["z_extent"])))
ax.set_zticklabels(show_axes["z_extent"])
else:
ax.set_axis_off()
return fig, ax
[docs]
def plot_frustum_vertices(
rho_range,
theta_range,
phi_range,
num_points=20,
phi_plane=None,
theta_plane=None,
rho_plane=None,
fig=None,
ax=None,
frustum_style=None,
phi_style=None,
theta_style=None,
rho_style=None,
):
"""
Plots the vertices of a frustum in spherical coordinates and highlights specified planes.
Args:
rho_range (tuple): Range of rho values (min, max).
theta_range (tuple): Range of theta values (min, max).
phi_range (tuple): Range of phi values (min, max).
num_points (int, optional): Number of points to generate along each edge.
Defaults to 20.
phi_plane (float or list, optional): Value(s) of phi at which to plot plane(s).
Defaults to None.
theta_plane (float or list, optional): Value(s) of theta at which to plot plane(s).
Defaults to None.
rho_plane (float or list, optional): Value(s) of rho at which to plot plane(s).
Defaults to None.
fig (matplotlib.figure.Figure, optional): Figure object to plot on.
Defaults to None. Can be used to reuse the figure in a loop.
ax (matplotlib.axes.Axes3DSubplot, optional): Axes object to plot on.
Defaults to None. Can be used to reuse the axes in a loop.
frustum_style (dict, optional): Style dictionary for frustum edges. Can include
'color', 'linestyle', 'linewidth', 'alpha', etc.
Defaults to {'color': 'blue', 'linestyle': '-', 'linewidth': 2}.
phi_style (dict, optional): Style dictionary for phi plane(s). Can include
'color', 'linestyle', 'linewidth', 'alpha', etc.
Defaults to {'color': 'yellow', 'linestyle': '-'}.
theta_style (dict, optional): Style dictionary for theta plane(s). Can include
'color', 'linestyle', 'linewidth', 'alpha', etc.
Defaults to {'color': 'green', 'linestyle': '--'}.
rho_style (dict, optional): Style dictionary for rho plane(s). Can include
'color', 'linestyle', 'linewidth', 'alpha', etc.
Defaults to {'color': 'red', 'linestyle': '--'}.
Returns:
tuple: A tuple containing the figure and axes objects (fig, ax).
Raises:
ValueError: If no plane is specified (phi_plane, theta_plane, or rho_plane).
Example:
.. doctest::
>>> from zea.visualize import plot_frustum_vertices
>>> rho_range = [0.1, 10] # in mm
>>> theta_range = [-0.6, 0.6] # in rad
>>> phi_range = [-0.6, 0.6] # in rad
>>> fig, ax = plot_frustum_vertices(
... rho_range,
... theta_range=theta_range,
... phi_range=phi_range,
... phi_plane=0,
... phi_style={"color": "red", "linestyle": "--", "linewidth": 2},
... theta_plane=0.2,
... theta_style={"color": "green", "linestyle": ":", "alpha": 0.7},
... frustum_style={"color": "blue", "linewidth": 1.5},
... )
"""
# Convert single values to lists
phi_plane = [phi_plane] if isinstance(phi_plane, (int, float)) else phi_plane
theta_plane = [theta_plane] if isinstance(theta_plane, (int, float)) else theta_plane
rho_plane = [rho_plane] if isinstance(rho_plane, (int, float)) else rho_plane
# Ensure at least one plane is specified
if all(p is None for p in [phi_plane, theta_plane, rho_plane]):
raise ValueError("At least one plane must be specified")
# Build style dictionaries with defaults
if frustum_style is None:
frustum_style = {"color": "blue", "linestyle": "-", "linewidth": 2}
if phi_style is None:
phi_style = {"color": "yellow", "linestyle": "-"}
if theta_style is None:
theta_style = {"color": "green", "linestyle": "--"}
if rho_style is None:
rho_style = {"color": "red", "linestyle": "--"}
# Define edges of the frustum
edges = []
# Edges along rho (vertical edges)
for theta in theta_range:
for phi in phi_range:
edges.append(((rho_range[0], theta, phi), (rho_range[1], theta, phi)))
# Edges along theta (near and far planes)
for rho in rho_range:
for phi in phi_range:
edges.append(((rho, theta_range[0], phi), (rho, theta_range[1], phi)))
# Edges along phi (near and far planes)
for rho in rho_range:
for theta in theta_range:
edges.append(((rho, theta, phi_range[0]), (rho, theta, phi_range[1])))
# Function to generate edge points
def generate_edge_points(start, end, num_points):
rho_points = np.linspace(start[0], end[0], num_points)
theta_points = np.linspace(start[1], end[1], num_points)
phi_points = np.linspace(start[2], end[2], num_points)
return rho_points, theta_points, phi_points
# Collect all points to determine axes limits
all_points = []
for edge in edges:
rho_pts, theta_pts, phi_pts = generate_edge_points(edge[0], edge[1], num_points)
x, y, z = frustum_convert_rtp2xyz(rho_pts, theta_pts, phi_pts)
all_points.extend(zip(x, y, -z)) # Flip z-axis
all_points = np.array(all_points)
x_min, x_max = np.min(all_points[:, 0]), np.max(all_points[:, 0])
y_min, y_max = np.min(all_points[:, 1]), np.max(all_points[:, 1])
z_min, z_max = np.min(all_points[:, 2]), np.max(all_points[:, 2])
if fig is None:
fig = plt.figure()
if ax is None:
ax = fig.add_subplot(111, projection="3d")
def _plot_edges(edges, **kwargs):
for edge in edges:
rho_pts, theta_pts, phi_pts = generate_edge_points(edge[0], edge[1], num_points)
x, y, z = frustum_convert_rtp2xyz(rho_pts, theta_pts, phi_pts)
ax.plot(x, y, -z, **kwargs)
# Plot frustum edges
_plot_edges(edges, **frustum_style)
def get_plane_edges(plane_value, plane_type):
"""Generate edges for a specific plane type (phi, theta, or rho)"""
if plane_type == "phi":
return [
(
(rho_range[0], theta_range[0], plane_value),
(rho_range[1], theta_range[0], plane_value),
),
(
(rho_range[0], theta_range[1], plane_value),
(rho_range[1], theta_range[1], plane_value),
),
(
(rho_range[0], theta_range[0], plane_value),
(rho_range[0], theta_range[1], plane_value),
),
(
(rho_range[1], theta_range[0], plane_value),
(rho_range[1], theta_range[1], plane_value),
),
]
elif plane_type == "theta":
return [
(
(rho_range[0], plane_value, phi_range[0]),
(rho_range[1], plane_value, phi_range[0]),
),
(
(rho_range[0], plane_value, phi_range[1]),
(rho_range[1], plane_value, phi_range[1]),
),
(
(rho_range[0], plane_value, phi_range[0]),
(rho_range[0], plane_value, phi_range[1]),
),
(
(rho_range[1], plane_value, phi_range[0]),
(rho_range[1], plane_value, phi_range[1]),
),
]
else: # rho
return [
(
(plane_value, theta_range[0], phi_range[0]),
(plane_value, theta_range[1], phi_range[0]),
),
(
(plane_value, theta_range[0], phi_range[1]),
(plane_value, theta_range[1], phi_range[1]),
),
(
(plane_value, theta_range[0], phi_range[0]),
(plane_value, theta_range[0], phi_range[1]),
),
(
(plane_value, theta_range[1], phi_range[0]),
(plane_value, theta_range[1], phi_range[1]),
),
]
# Plot plane edges
plane_configs = [
(phi_plane, "phi", phi_style),
(theta_plane, "theta", theta_style),
(rho_plane, "rho", rho_style),
]
for planes, plane_type, style_dict in plane_configs:
if planes is not None:
for plane_value in planes:
plane_edges = get_plane_edges(plane_value, plane_type)
_plot_edges(plane_edges, **style_dict)
# Set axes properties
ax.set_xlim([x_min, x_max])
ax.set_ylim([y_min, y_max])
ax.set_zlim([z_min, z_max])
ax.set_axis_off()
ax.grid(False)
ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False
return fig, ax
[docs]
def visualize_matrix(matrix, font_color="white", **kwargs):
"""
Visualize a matrix with the values in each cell.
"""
fig, ax = plt.subplots()
cax = ax.imshow(matrix, **kwargs)
fig.colorbar(cax)
for (j, i), label in np.ndenumerate(matrix):
ax.text(i, j, f"{label:.2f}", ha="center", va="center", color=font_color)
return fig
[docs]
def pad_or_crop_extent(image, extent, target_extent):
"""Pads and/or crops the extent of an image to match a target extent.
This is useful for side by side comparison of images with different extents.
Args:
image (np.ndarray): The input image to be padded and/or cropped.
Only 2D images are supported. Image shape must be (grid_size_z, grid_size_x).
extent (tuple): The current extent of the image in the format
(x_min, x_max, z_min, z_max).
target_extent (tuple): The target extent to match in the format
(x_min, x_max, z_min, z_max).
Returns:
np.ndarray: The padded and/or cropped image.
"""
x_min, x_max, z_min, z_max = extent
target_x_min, target_x_max, target_z_min, target_z_max = target_extent
pixel_per_mm = np.array(image.shape) / np.array([z_max - z_min, x_max - x_min])
pixels_to_add_left = int((x_min - target_x_min) * pixel_per_mm[1])
pixels_to_add_right = int((target_x_max - x_max) * pixel_per_mm[1])
pixels_to_add_top = int((z_min - target_z_min) * pixel_per_mm[0])
pixels_to_add_bottom = int((target_z_max - z_max) * pixel_per_mm[0])
# crop if negative, pad if positive
pixels_to_crop_left = max(0, -pixels_to_add_left)
pixels_to_crop_right = max(0, -pixels_to_add_right)
pixels_to_crop_top = max(0, -pixels_to_add_top)
pixels_to_crop_bottom = max(0, -pixels_to_add_bottom)
pixels_to_pad_left = max(0, pixels_to_add_left)
pixels_to_pad_right = max(0, pixels_to_add_right)
pixels_to_pad_top = max(0, pixels_to_add_top)
pixels_to_pad_bottom = max(0, pixels_to_add_bottom)
# Crop the image
image_cropped = crop_images(
image[..., None],
pixels_to_crop_top,
pixels_to_crop_left,
pixels_to_crop_bottom,
pixels_to_crop_right,
data_format="channels_last",
)[..., 0]
# Pad the image
image_padded = np.pad(
image_cropped,
((pixels_to_pad_top, pixels_to_pad_bottom), (pixels_to_pad_left, pixels_to_pad_right)),
mode="constant",
constant_values=0,
)
return image_padded
[docs]
def plot_rectangle_from_mask(ax, mask, **kwargs):
"""Plots a rectangle box to axis from mask array.
Is a simplified version of plot_shape_from_mask for rectangles.
Useful for displaying bounding boxes on top of images.
Args:
ax (plt.ax): matplotlib axis
mask (ndarray): numpy array with rectangle non-zero
box defining the region of interest.
Kwargs:
edgecolor (str): color of the shape's edge
facecolor (str): color of the shape's face
linewidth (int): width of the shape's edge
Returns:
matplotlib.patches.Rectangle: the added rectangle patch, or None if mask is empty.
"""
ys, xs = np.where(mask)
if ys.size == 0 or xs.size == 0:
return None
y1, y2 = ys.min(), ys.max()
x1, x2 = xs.min(), xs.max()
rect = Rectangle((x1, y1), x2 - x1 + 1, y2 - y1 + 1, **kwargs)
return ax.add_patch(rect)
[docs]
def plot_shape_from_mask(ax, mask, **kwargs):
"""Plots a shape to axis from mask array.
Is useful for displaying irregular shapes such as segmentations
on top of images.
Args:
ax (plt.ax): matplotlib axis
mask (ndarray): numpy array with non-zero
shape defining the region of interest.
Kwargs:
edgecolor (str): color of the shape's edge
facecolor (str): color of the shape's face
linewidth (int): width of the shape's edge
Returns:
list[matplotlib.patches.PathPatch]: list of matplotlib patch objects
added to the axis.
Example:
.. code-block:: python
import matplotlib.pyplot as plt
import numpy as np
from zea.visualize import plot_shape_from_mask
y, x = np.ogrid[-50:50, -50:50]
mask = x**2 + y**2 <= 30**2
fig, ax = plt.subplots()
ax.imshow(np.random.rand(100, 100), cmap="gray")
plot_shape_from_mask(ax, mask, edgecolor="red", alpha=0.5)
"""
# Pad mask to ensure edge contours are found
padded_mask = np.pad(mask, pad_width=1, mode="constant", constant_values=0)
contours = measure.find_contours(padded_mask, 0.5)
patches = []
for contour in contours:
# Remove padding offset
contour -= 1
path = pltPath(contour[:, ::-1])
patch = PathPatch(path, **kwargs)
patches.append(ax.add_patch(patch))
return patches