Source code for zea.data.convert.echonet

"""
Script to convert the EchoNet database to zea format.

.. note::
    Will segment the images and convert them to polar coordinates.

For more information about the dataset, resort to the following links:

- The original dataset can be found at `this link <https://stanfordaimi.azurewebsites.net/datasets/834e1cd1-92f7-4268-9daa-d359198b310a>`_.
- The project page is available `here <https://echonet.github.io/>`_.

"""

import os
from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import Value
from pathlib import Path

import numpy as np
import yaml
from scipy.interpolate import griddata
from tqdm import tqdm

from zea import log
from zea.data import generate_zea_dataset
from zea.data.convert.utils import load_avi, unzip
from zea.func.tensor import translate


[docs] def segment(tensor, number_erasing=0, min_clip=0): """Segments the background of the echonet images by setting it to 0 and creating a hard edge. Args: tensor (ndarray): Input image (sc) with 3 dimensions. (N, 112, 112) number_erasing (float, optional): number to fill the background with. min_clip (float, optional): If > 0, values on the computed cone edge will be clipped to be at least this value. Defaults to 0. Returns: tensor (ndarray): Segmented matrix of same dimensions as input """ # Start with the upper part # Height of the diagonal lines for the columns [0, 112] rows_left = np.linspace(67, 7, 61) rows_right = np.linspace(7, 57, 51) rows = np.concatenate([rows_left, rows_right], axis=0) for idx, row in enumerate(rows.astype(np.int32)): # Set everything above the edge to the number_erasing value. # Rows count up from 0 to 112 so row-1 is above. tensor[:, 0 : row - 1, idx] = number_erasing # Set minimum values for the edge if min_clip > 0: tensor[:, row, idx] = np.clip(tensor[:, row, idx], min_clip, 1) # Bottom left curve (manual fit) cols_left = np.linspace(0, 20, 21).astype(np.int32) rows_left = np.array( [ 102, 103, 103, 104, 104, 105, 105, 106, 106, 107, 107, 107, 108, 108, 109, 109, 109, 110, 110, 111, 111, ] ) # Bottom right curve (manual fit) cols_right = np.linspace(89, 111, 23).astype(np.int32) rows_right = np.array( [ 111, 111, 111, 110, 110, 110, 109, 109, 109, 108, 108, 107, 107, 107, 106, 106, 105, 105, 104, 104, 103, 103, 102, ] ) rows = np.concatenate([rows_left, rows_right], axis=0) cols = np.concatenate([cols_left, cols_right], axis=0) for row, col in zip(rows, cols): # Set everything under the edge to the number_erasing value. # Rows count up from 0 to 112 so row-1 is above. tensor[:, row:, col] = number_erasing # Set minimum values for the edge if min_clip > 0: tensor[:, row - 1, col] = np.clip(tensor[:, row - 1, col], min_clip, 1) return tensor
[docs] def accept_shape(tensor): """Acceptance algorithm that determines whether to reject an image based on left and right corner data. Args: tensor (ndarray): Input image (sc) with 2 dimensions. (112, 112) Returns: decision (bool): Whether or not the tensor should be rejected. """ decision = True # Test one, check if left bottom corner is populated with values rows_lower = np.linspace(78, 47, 21).astype(np.int32) rows_upper = np.linspace(67, 47, 21).astype(np.int32) counter = 0 for idx, row in enumerate(rows_lower): counter += np.sum(tensor[rows_upper[idx] : row, idx]) # If it is not populated, reject the image if counter < 0.1: decision = False # Test two, check if the bottom right cornered with values (that are not artifacts) cols = np.linspace(70, 111, 42).astype(np.int32) rows_bot = np.linspace(17, 57, 42).astype(np.int32) rows_top = np.linspace(17, 80, 42).astype(np.int32) # List all the values counter = [] for i, col in enumerate(cols): counter += [tensor[rows_bot[i] : rows_top[i], col]] flattened_counter = [float(item) for sublist in counter for item in sublist] # Sort and exclude the first 50 (likely artifacts) flattened_counter.sort(reverse=True) value = sum(flattened_counter[100:]) # Reject if the baseline is too low if value < 5: decision = False return decision
[docs] def rotate_coordinates(data_points, degrees): """Function that rotates the datapoints by a certain degree. Args: data_points (ndarray): tensor containing [N,2] (x and y) datapoints. degrees (int): angle to rotate the datapoints with Returns: rotated_points (ndarray): the rotated data_points. """ angle_radians = np.radians(degrees) cos_angle = np.cos(angle_radians) sin_angle = np.sin(angle_radians) rotation_matrix = np.array([[cos_angle, -sin_angle], [sin_angle, cos_angle]]) rotated_points = rotation_matrix @ data_points.T return rotated_points.T
[docs] def cartesian_to_polar_matrix( cartesian_matrix, tip=(61, 7), r_max=107, angle=0.79, interpolation="nearest" ): """ Function that converts a timeseries of a cartesian cone to a polar representation that is more compatible with CNN's/action selection. Args: - cartesian_matrix (2d array): (rows, cols) matrix containing time sequence of image_sc data. - tip (tuple, optional): coordinates (in indices) of the tip of the cone. Defaults to (61, 7). - r_max (int, optional): expected radius of the cone. Defaults to 107. - angle (float, optional): expected angle of the cone, will be used as (-angle, angle). Defaults to 0.79. - interpolation (str, optional): can be [nearest, linear, cubic]. Defaults to 'nearest'. Returns: polar_matrix (2d array): polar conversion of the input. """ rows, cols = cartesian_matrix.shape center_x, center_y = tip # Create cartesian coordinates of the image data x = np.linspace(-center_x, cols - center_x - 1, cols) y = np.linspace(-center_y, rows - center_y - 1, rows) x, y = np.meshgrid(x, y) # Flatten the grid and values data_points = np.column_stack((x.ravel(), y.ravel())) data_points = rotate_coordinates(data_points, -90) data_values = cartesian_matrix.ravel() # Define new points to sample from in the region of the data. # R_max and Theta are found manually. R_max differs from the number of rows in EchoNet! r = np.linspace(0, r_max, rows) theta = np.linspace(-angle, angle, cols) r, theta = np.meshgrid(r, theta) x_polar = r * np.cos(theta) y_polar = r * np.sin(theta) new_points = np.column_stack((x_polar.ravel(), y_polar.ravel())) # Interpolate and reshape to 2D matrix polar_values = griddata( data_points, data_values, new_points, method=interpolation, fill_value=0 ) polar_matrix = np.rot90(polar_values.reshape(cols, rows), k=-1) return polar_matrix
[docs] def find_split_for_file(file_dict, target_file): """ Locate which split contains a given filename. Parameters: file_dict (dict): Mapping from split name (e.g., "train", "val", "test", "rejected") to an iterable of filenames. target_file (str): Filename to search for within the split lists. Returns: str: The split name that contains `target_file`, or `"rejected"` if the file is not found. """ for split, files in file_dict.items(): if target_file in files: return split log.warning(f"File {target_file} not found in any split, defaulting to rejected.") return "rejected"
[docs] def count_init(shared_counter): """ Initialize the module-level shared counter used by worker processes. Parameters: shared_counter (multiprocessing.Value): A process-shared integer Value that will be assigned to the module-global COUNTER for coordinated counting across processes. """ global COUNTER COUNTER = shared_counter
[docs] class H5Processor: """ Stores a few variables and paths to allow for hyperthreading. """ def __init__( self, path_out_h5, num_val=500, num_test=500, range_from=(0, 255), range_to=(-60, 0), splits=None, ): self.path_out_h5 = Path(path_out_h5) self.num_val = num_val self.num_test = num_test self.range_from = range_from self.range_to = range_to self.splits = splits self._process_range = (0, 1) # Ensure train, val, test, rejected paths exist for folder in ["train", "val", "test", "rejected"]: (self.path_out_h5 / folder).mkdir(parents=True, exist_ok=True) def _translate(self, data): """Translate the data from the processing range to final range.""" return translate(data, self._process_range, self.range_to)
[docs] def get_split(self, hdf5_file: str, sequence): """ Determine the dataset split label for a given file and its image sequence. This method checks acceptance based on the first frame of `sequence`. If explicit splits were provided to the processor, it returns the split found for `hdf5_file` (and asserts that the acceptance result matches the split). If no explicit splits are provided, rejected sequences are labeled `"rejected"`. Accepted sequences increment a shared counter and are assigned `"val"`, `"test"`, or `"train"` according to the processor's `num_val` and `num_test` quotas. Args: hdf5_file (str): Filename or identifier used to look up an existing split when splits are provided. sequence (array-like): Time-ordered sequence of images; the first frame is used for acceptance checking. Returns: str: One of `"train"`, `"val"`, `"test"`, or `"rejected"` indicating the assigned split. """ # Always check acceptance accepted = accept_shape(sequence[0]) # Previous split if self.splits is not None: split = find_split_for_file(self.splits, hdf5_file) assert accepted == (split != "rejected"), "Rejection mismatch" return split # New split if not accepted: return "rejected" # Increment the hyperthreading counter # Note that some threads will start on subsequent splits # while others are still processing with COUNTER.get_lock(): COUNTER.value += 1 n = COUNTER.value # Determine the split if n <= self.num_val: return "val" elif n <= self.num_val + self.num_test: return "test" else: return "train"
[docs] def validate_split_copy(self, split_file): """ Validate that a generated split YAML matches the original splits provided to the processor. Reads the YAML at `split_file` and compares its `train`, `val`, `test`, and `rejected` lists (or other split keys present in `self.splits`) against `self.splits`; logs confirmation when a split matches and logs which entries are missing or extra when they differ. If the processor was not initialized with `splits`, validation is skipped and a message is logged. Args: split_file (str or os.PathLike): Path to the YAML file containing the generated dataset splits. """ if self.splits is not None: # Read the split_file and ensure contents of the train, val and split match with open(split_file, "r") as f: new_splits = yaml.safe_load(f) for split in self.splits.keys(): if set(new_splits[split]) == set(self.splits[split]): log.info(f"Split {split} copied correctly.") else: # Log which entry is missing or extra in the split_file missing = set(self.splits[split]) - set(new_splits[split]) extra = set(new_splits[split]) - set(self.splits[split]) if missing: log.warning(f"New dataset split {split} is missing entries: {missing}") if extra: log.warning(f"New dataset split {split} has extra entries: {extra}") else: log.info( "Processor not initialized with a split, not validating if the split was copied." )
[docs] def __call__(self, avi_file): """ Convert a single AVI file into a zea dataset entry. Loads the AVI, validates and rescales pixel ranges, applies segmentation, assigns a data split (train/val/test/rejected), converts accepted frames to polar coordinates. Constructs and returns the zea dataset descriptor used by generate_zea_dataset; the descriptor always includes `path`, `image_sc`, `probe_name`, and `description`, and includes `image` when the file is accepted. Args: avi_file (pathlib.Path): Path to the source .avi file to process. Returns: dict: The value returned by generate_zea_dataset containing the dataset entry for the processed file. """ hdf5_file = avi_file.stem + ".hdf5" sequence = load_avi(avi_file) assert sequence.min() >= self.range_from[0], f"{sequence.min()} < {self.range_from[0]}" assert sequence.max() <= self.range_from[1], f"{sequence.max()} > {self.range_from[1]}" # Translate to [0, 1] sequence = translate(sequence, self.range_from, self._process_range) sequence = segment(sequence, number_erasing=0, min_clip=0) split = self.get_split(hdf5_file, sequence) accepted = split != "rejected" out_h5 = self.path_out_h5 / split / hdf5_file polar_im_set = [] for _, im in enumerate(sequence): if not accepted: continue polar_im = cartesian_to_polar_matrix(im, interpolation="cubic") polar_im = np.clip(polar_im, *self._process_range) polar_im_set.append(polar_im) if accepted: polar_im_set = np.stack(polar_im_set, axis=0) # Check the ranges assert sequence.min() >= self._process_range[0], sequence.min() assert sequence.max() <= self._process_range[1], sequence.max() zea_dataset = { "path": out_h5, "image_sc": self._translate(sequence), "probe_name": "generic", "description": "EchoNet dataset converted to zea format", } if accepted: zea_dataset["image"] = self._translate(polar_im_set) return generate_zea_dataset(**zea_dataset)
[docs] def convert_echonet(args): """ Convert an EchoNet dataset into zea files, organizing results into train/val/test/rejected splits. Args: args (argparse.Namespace): An object with the following attributes. - src (str|Path): Path to the source archive or directory containing .avi files. Will be unzipped if needed. - dst (str|Path): Destination directory for generated zea files per-split subdirectories (train, val, test, rejected) and a split.yaml are created or updated. - split_path (str|Path|None): If provided, must contain a split.yaml to reproduce an existing split; function asserts the file exists. - no_hyperthreading (bool): When false, processing uses a ProcessPoolExecutor with a shared counter; when true, processing runs sequentially. Note: - May unzip the source into a working directory. - Writes zea files into dst. - Writes a split.yaml into dst summarizing produced files per split. - Logs progress and validation results. - Asserts that split.yaml exists at split_path when split reproduction is requested. """ # Check if unzip is needed src = unzip(args.src, "echonet") if args.split_path is not None: # Reproduce a previous split... yaml_file = Path(args.split_path) / "split.yaml" assert yaml_file.exists(), f"File {yaml_file} does not exist." splits = {"train": None, "val": None, "test": None, "rejected": None} with open(yaml_file, "r") as f: splits = yaml.safe_load(f) log.info(f"Processor initialized with train-val-test split from {yaml_file}.") else: splits = None # List the files that have an entry in path_out_h5 already files_done = [] for _, _, filenames in os.walk(args.dst): for filename in filenames: files_done.append(filename.replace(".hdf5", "")) # List all files of echonet and exclude those already processed path_in = Path(src) h5_files = path_in.glob("*.avi") h5_files = [file for file in h5_files if file.stem not in files_done] log.info(f"Files left to process: {len(h5_files)}") # Run the processor processor = H5Processor(path_out_h5=args.dst, splits=splits) log.info("Starting the conversion process.") if not args.no_hyperthreading: shared_counter = Value("i", 0) with ProcessPoolExecutor(initializer=count_init, initargs=(shared_counter,)) as executor: futures = [executor.submit(processor, file) for file in h5_files] for future in tqdm(as_completed(futures), total=len(futures)): try: future.result() except Exception: log.warning("Task raised an exception") else: # Initialize global variable for counting count_init(Value("i", 0)) for file in tqdm(h5_files): processor(file) log.info("All tasks are completed.") # Write to yaml split files full_list = {} for split in ["train", "val", "test", "rejected"]: split_dir = Path(args.dst) / split # Get only files (skip directories) file_list = [f.name for f in split_dir.iterdir() if f.is_file()] full_list[split] = file_list with open(Path(args.dst) / "split.yaml", "w") as f: yaml.dump(full_list, f) # Validate that the split was copied correctly processor.validate_split_copy(Path(args.dst) / "split.yaml")