
"""Metrics for ultrasound images."""

from functools import partial
from typing import List

import keras
import numpy as np
from keras import ops

from zea import log
from zea.backend import func_on_device, jit
from zea.func import tensor
from zea.func.tensor import translate
from zea.internal.registry import metrics_registry
from zea.internal.utils import reduce_to_signature
from zea.models.lpips import LPIPS


def get_metric(name, **kwargs):
    """Get metric function given name."""
    metric_fn = metrics_registry[name]
    if not metric_fn.__name__.startswith("get_"):
        return partial(metric_fn, **kwargs)
    log.info(f"Initializing metric: {log.green(name)}")
    return metric_fn(**kwargs)
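

# A minimal usage sketch: plain metric functions are returned directly (partially
# applied with any kwargs), while registry entries whose function name starts with
# ``get_`` are factories and are called first to build the metric:
#
#   >>> mse_fn = get_metric("mse")                            # plain metric
#   >>> lpips_fn = get_metric("lpips", image_range=[0, 255])  # built via get_lpips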


def _reduce_mean(array, keep_batch_dim=True):
    """Reduce array by taking the mean.

    Args:
        array (tensor): Input tensor of shape (..., height, width, channels).
        keep_batch_dim (bool): Whether to keep the batch dimensions when reducing.
            Default is True.
    """
    if keep_batch_dim:
        return ops.mean(array, axis=(-3, -2, -1))
    else:
        return ops.mean(array)
[docs] @metrics_registry(name="cnr", paired=True, jittable=True) def cnr(x, y): """Calculate contrast to noise ratio""" mu_x = ops.mean(x) mu_y = ops.mean(y) var_x = ops.var(x) var_y = ops.var(y) return 20 * ops.log10(ops.abs(mu_x - mu_y) / ops.sqrt((var_x + var_y) / 2))
[docs] @metrics_registry(name="contrast", paired=True, jittable=True) def contrast(x, y): """Contrast ratio""" return 20 * ops.log10(ops.mean(x) / ops.mean(y))
[docs] @metrics_registry(name="gcnr", paired=True, jittable=False) def gcnr(x, y, bins=256): """Generalized contrast-to-noise-ratio""" x = ops.convert_to_numpy(x) y = ops.convert_to_numpy(y) x = np.ravel(x) y = np.ravel(y) _, bins = np.histogram(np.concatenate((x, y)), bins=bins) f, _ = np.histogram(x, bins=bins, density=True) g, _ = np.histogram(y, bins=bins, density=True) f /= np.sum(f) g /= np.sum(g) return 1 - np.sum(np.minimum(f, g))
[docs] @metrics_registry(name="fwhm", paired=False, jittable=False) def fwhm(img): """Resolution full width half maxima""" mask = ops.nonzero(img >= 0.5 * ops.amax(img))[0] return mask[-1] - mask[0]
[docs] @metrics_registry(name="snr", paired=False, jittable=True) def snr(img): """Signal to noise ratio""" return ops.mean(img) / ops.std(img)
[docs] @metrics_registry(name="wopt_mae", paired=True, jittable=True) def wopt_mae(ref, img): """Find the optimal weight that minimizes the mean absolute error""" wopt = ops.median(ref / img) return wopt
[docs] @metrics_registry(name="wopt_mse", paired=True, jittable=True) def wopt_mse(ref, img): """Find the optimal weight that minimizes the mean squared error""" wopt = ops.sum(ref * img) / ops.sum(img * img) return wopt
[docs] @metrics_registry(name="psnr", paired=True, jittable=True) def psnr(y_true, y_pred, *, max_val=255): """Peak Signal to Noise Ratio (PSNR) for two input tensors. PSNR = 20 * log10(max_val) - 10 * log10(mean(square(y_true - y_pred))) Args: y_true (tensor): input tensor of shape (height, width, channels) with optional batch dimension. y_pred (tensor): input tensor of shape (height, width, channels) with optional batch dimension. max_val: The dynamic range of the images Returns: Tensor (float): PSNR score for each image in the batch. """ mse = _reduce_mean(ops.square(y_true - y_pred)) psnr = 20 * ops.log10(max_val) - 10 * ops.log10(mse) return psnr
[docs] @metrics_registry(name="mse", paired=True, jittable=True) def mse(y_true, y_pred): """Gives the MSE for two input tensors. Args: y_true (tensor): input tensor of shape (height, width, channels) with optional batch dimension. y_pred (tensor): input tensor of shape (height, width, channels) with optional batch dimension. Returns: (float): mean squared error between y_true and y_pred. L2 loss. """ return _reduce_mean(ops.square(y_true - y_pred))
[docs] @metrics_registry(name="mae", paired=True, jittable=True) def mae(y_true, y_pred): """Gives the MAE for two input tensors. Args: y_true (tensor): input tensor of shape (height, width, channels) with optional batch dimension. y_pred (tensor): input tensor of shape (height, width, channels) with optional batch dimension. Returns: (float): mean absolute error between y_true and y_pred. L1 loss. """ return _reduce_mean(ops.abs(y_true - y_pred))
[docs] @metrics_registry(name="ssim", paired=True, jittable=True) def ssim( a, b, *, max_val: float = 255.0, filter_size: int = 11, filter_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03, return_map: bool = False, filter_fn=None, ): """Computes the structural similarity index (SSIM) between image pairs. This function is based on the standard SSIM implementation from: Z. Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli, "Image quality assessment: from error visibility to structural similarity", in IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, 2004. This function copied from [`dm_pix.ssim`](https://dm-pix.readthedocs.io/en/latest/api.html#dm_pix.ssim), which is part of the DeepMind's `dm_pix` library. They modeled their implementation after the `tf.image.ssim` function. Note: the true SSIM is only defined on grayscale. This function does not perform any colorspace transform. If the input is in a color space, then it will compute the average SSIM. Args: a: First image (or set of images). b: Second image (or set of images). max_val: The maximum magnitude that `a` or `b` can have. filter_size: Window size (>= 1). Image dims must be at least this small. filter_sigma: The bandwidth of the Gaussian used for filtering (> 0.). k1: One of the SSIM dampening parameters (> 0.). k2: One of the SSIM dampening parameters (> 0.). return_map: If True, will cause the per-pixel SSIM "map" to be returned. filter_fn: An optional argument for overriding the filter function used by SSIM, which would otherwise be a 2D Gaussian blur specified by filter_size and filter_sigma. Returns: Each image's mean SSIM, or a tensor of individual values if `return_map`. """ if filter_fn is None: # Construct a 1D Gaussian blur filter. hw = filter_size // 2 shift = (2 * hw - filter_size + 1) / 2 f_i = ((ops.cast(ops.arange(filter_size), "float32") - hw + shift) / filter_sigma) ** 2 filt = ops.exp(-0.5 * f_i) filt /= ops.sum(filt) # Construct a 1D convolution. def filter_fn_1(z): return tensor.correlate(z, ops.flip(filt), mode="valid") # Apply the vectorized filter along the y axis. def filter_fn_y(z): z_flat = ops.reshape(ops.moveaxis(z, -3, -1), (-1, z.shape[-3])) z_filtered_shape = ((z.shape[-4],) if z.ndim == 4 else ()) + ( z.shape[-2], z.shape[-1], -1, ) _z_filtered = ops.vectorized_map(filter_fn_1, z_flat) z_filtered = ops.moveaxis(ops.reshape(_z_filtered, z_filtered_shape), -1, -3) return z_filtered # Apply the vectorized filter along the x axis. def filter_fn_x(z): z_flat = ops.reshape(ops.moveaxis(z, -2, -1), (-1, z.shape[-2])) z_filtered_shape = ((z.shape[-4],) if z.ndim == 4 else ()) + ( z.shape[-3], z.shape[-1], -1, ) _z_filtered = ops.vectorized_map(filter_fn_1, z_flat) z_filtered = ops.moveaxis(ops.reshape(_z_filtered, z_filtered_shape), -1, -2) return z_filtered # Apply the blur in both x and y. filter_fn = lambda z: filter_fn_y(filter_fn_x(z)) mu0 = filter_fn(a) mu1 = filter_fn(b) mu00 = mu0 * mu0 mu11 = mu1 * mu1 mu01 = mu0 * mu1 sigma00 = filter_fn(a**2) - mu00 sigma11 = filter_fn(b**2) - mu11 sigma01 = filter_fn(a * b) - mu01 # Clip the variances and covariances to valid values. 
# Variance must be non-negative: epsilon = keras.config.epsilon() sigma00 = ops.maximum(epsilon, sigma00) sigma11 = ops.maximum(epsilon, sigma11) sigma01 = ops.sign(sigma01) * ops.minimum(ops.sqrt(sigma00 * sigma11), ops.abs(sigma01)) c1 = (k1 * max_val) ** 2 c2 = (k2 * max_val) ** 2 numer = (2 * mu01 + c1) * (2 * sigma01 + c2) denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) ssim_map = numer / denom ssim_value = ops.mean(ssim_map, axis=tuple(range(-3, 0))) return ssim_map if return_map else ssim_value
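

# Example usage on single images of shape (h, w, c); a leading batch axis is also
# supported. Identical inputs give an SSIM of 1:
#
#   >>> import numpy as np
#   >>> a = np.random.rand(64, 64, 1).astype("float32")
#   >>> float(ssim(a, a, max_val=1.0))  # -> 1.0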
[docs] @metrics_registry(name="ncc", paired=True, jittable=True) def ncc(x, y): """Normalized cross correlation""" num = ops.sum(x * y) denom = ops.sqrt(ops.sum(x**2) * ops.sum(y**2)) return num / ops.maximum(denom, keras.config.epsilon())
[docs] @metrics_registry(name="lpips", paired=True, jittable=True) def get_lpips(image_range, clip=False): """ Get the Learned Perceptual Image Patch Similarity (LPIPS) metric. Args: image_range (list): The range of the images. Will be translated to [-1, 1] for LPIPS. clip (bool): Whether to clip the images to `image_range`. Returns: The LPIPS metric function. """ # Get the LPIPS model _lpips = LPIPS.from_preset("lpips") _lpips.trainable = False _lpips.disable_checks = True def lpips(img1, img2): """ The LPIPS metric function. Args: img1 (tensor) with shape (height, width, channels) with optional batch dimension img2 (tensor) with shape (height, width, channels) with optional batch dimension Returns (float): The LPIPS metric between img1 and img2 with shape (batch_size,) or scalar if no batch dimension. """ # clip and translate images to [-1, 1] if clip: img1 = ops.clip(img1, *image_range) img2 = ops.clip(img2, *image_range) img1 = translate(img1, image_range, [-1, 1]) img2 = translate(img2, image_range, [-1, 1]) return _lpips([img1, img2]) return lpips


class Metrics:
    """Class for calculating multiple paired metrics. Also useful for batch processing.

    Will preprocess images by translating to [0, 255], clipping, and quantizing
    to uint8 if specified.

    Example:
        .. doctest::

            >>> from zea import metrics
            >>> import numpy as np
            >>> metrics = metrics.Metrics(["psnr", "lpips"], image_range=[0, 255])
            >>> y_true = np.random.rand(4, 128, 128, 1)
            >>> y_pred = np.random.rand(4, 128, 128, 1)
            >>> result = metrics(y_true, y_pred)
            >>> result = {k: float(v) for k, v in result.items()}
            >>> print(result)  # doctest: +ELLIPSIS
            {'psnr': ..., 'lpips': ...}
    """

    def __init__(
        self,
        metrics: List[str],
        image_range: tuple,
        quantize: bool = False,
        clip: bool = False,
        jit_compile: bool = True,
        **kwargs,
    ):
        """Initialize the Metrics class.

        Args:
            metrics (list): List of metric names to calculate.
            image_range (tuple): The range of the images. Used for metrics like
                PSNR and LPIPS.
            quantize (bool): Whether to quantize the images to uint8 before
                calculating metrics.
            clip (bool): Whether to clip the images to ``image_range`` before
                calculating metrics.
            jit_compile (bool): Whether to jit-compile metrics that are jittable.
            kwargs: Additional keyword arguments to pass to the metric functions.
        """
        # Assert all metrics are paired
        for m in metrics:
            assert metrics_registry.get_parameter(m, "paired"), (
                f"Metric {m} is not a paired metric."
            )

        # Add image_range to kwargs for metrics that require it
        kwargs["image_range"] = image_range
        self.image_range = image_range

        # Initialize all metrics
        self.metrics = {}
        for m in metrics:
            jittable = metrics_registry.get_parameter(m, "jittable")
            metric_fn = get_metric(m, **reduce_to_signature(metrics_registry[m], kwargs))
            if jit_compile and jittable:
                metric_fn = jit(metric_fn)
            self.metrics[m] = metric_fn

        # Other settings
        self.quantize = quantize
        self.clip = clip

    @staticmethod
    def _call_metric_fn(
        fun, y_true, y_pred, average_batches, return_numpy, device, mapped_batch_size=None
    ):
        num_batch_axes = max(0, ops.ndim(y_true) - 3)
        # Because most metric functions do not support batching, we vmap over the batch axes.
        # This does assume that the metric function can handle single images of shape (h, w, c).
        metric_fn = fun
        for _ in range(num_batch_axes):
            # recursively vmap the leading axis
            metric_fn = tensor.vmap(
                metric_fn, in_axes=0, _use_torch_vmap=True, batch_size=mapped_batch_size
            )
        out = func_on_device(metric_fn, device, y_true, y_pred)
        if average_batches:
            out = ops.mean(out)
        if return_numpy:
            out = ops.convert_to_numpy(out)
        return out

    def _preprocess(self, tensor):
        tensor = translate(tensor, self.image_range, [0, 255])
        if self.clip:
            tensor = ops.clip(tensor, 0, 255)
        if self.quantize:
            tensor = ops.cast(tensor, "uint8")
            tensor = ops.cast(tensor, "float32")  # Some metrics require float32
        return tensor

    def __call__(
        self,
        y_true,
        y_pred,
        average_batches=True,
        mapped_batch_size=None,
        return_numpy=True,
        device=None,
    ):
        """Calculate all metrics and return as a dictionary.

        Assumes input shape (..., h, w, c), i.e. images of shape (h, w, c) with
        any number of leading batch dimensions. The metrics will be calculated
        on these 2D images and mapped across all leading batch dimensions.

        Args:
            y_true (tensor): Ground truth images with shape (..., h, w, c).
            y_pred (tensor): Predicted images with shape (..., h, w, c).
            average_batches (bool): Whether to average the metrics over the
                batch dimensions.
            mapped_batch_size (optional int): The batch size to use for computing
                metric values in parallel. You may want to decrease this if you
                run into memory issues, e.g. with LPIPS.
            return_numpy (bool): Whether to return the metrics as numpy arrays.
                If False, will return as tensors.
            device (str): The device to run the metric calculations on.
                If None, will use the default device.
        """
        results = {}
        for name, metric in self.metrics.items():
            results[name] = self._call_metric_fn(
                metric,
                self._preprocess(y_true),
                self._preprocess(y_pred),
                average_batches,
                return_numpy,
                device,
                mapped_batch_size=mapped_batch_size,
            )
        return results
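

# To get per-frame values instead of a single average, pass inputs of shape
# (n_frames, h, w, c) with ``average_batches=False``:
#
#   >>> import numpy as np
#   >>> y_true = np.random.rand(4, 64, 64, 1)
#   >>> y_pred = np.random.rand(4, 64, 64, 1)
#   >>> m = Metrics(["psnr"], image_range=[0, 1])
#   >>> m(y_true, y_pred, average_batches=False)["psnr"].shape  # (4,)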


def _sector_reweight_image(image, sector_angle, axis):
    """Reweights image according to the amount of area each row of pixels will occupy
    if that image is scan converted with angle ``sector_angle``.

    This 'image' could be e.g. a pixelwise loss or metric. We can compute this by
    viewing the scan-converted image as the sector of a circle with a known central
    angle, and radius given by depth.
    See: https://en.wikipedia.org/wiki/Circular_sector

    Args:
        image (ndarray or Tensor): Image to be reweighted, any shape.
        sector_angle (float | int): Angle in degrees.
        axis (int): Axis corresponding to the height/depth dimension.

    Returns:
        reweighted_image (ndarray): Image with pixels reweighted to the area
        occupied by each pixel post-scan-conversion.
    """
    height = image.shape[axis]
    depths = ops.arange(height, dtype="float32") + 0.5  # center of the pixel as its depth
    reweighting_factors = (sector_angle / 360) * 2 * np.pi * depths
    # Reshape reweighting_factors to broadcast along the specified axis
    shape = [1] * ops.ndim(image)
    shape[axis] = height
    reweighting_factors = ops.reshape(reweighting_factors, shape)
    return reweighting_factors * image
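

# Worked example of the sector geometry: a row of pixels at depth r, scan converted
# over a sector of ``sector_angle`` degrees, maps onto an arc of length
# (sector_angle / 360) * 2 * pi * r, so deeper rows receive proportionally larger
# weights. For a 90-degree sector, row i is weighted by (pi / 2) * (i + 0.5):
#
#   >>> import numpy as np
#   >>> image = np.ones((4, 3), dtype="float32")
#   >>> weighted = _sector_reweight_image(image, sector_angle=90, axis=0)
#   >>> # first row weight: (90 / 360) * 2 * pi * 0.5 ≈ 0.785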