"""
Script to convert the EchoNet-LVH database to zea format.
Each video is cropped so that the scan cone is centered
without padding, such that it can be converted to the polar domain.

.. note::
    This cropping requires first computing scan cone parameters
    using :mod:`zea.data.convert.echonetlvh.precompute_crop`, which
    are then passed to this script.

For more information about the dataset, refer to the following link:

- The original dataset can be found at `this link <https://stanfordaimi.azurewebsites.net/datasets/5b7fcc28-579c-4285-8b72-e4238eac7bd1>`_.
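
The source directory is expected to contain ``Batch*`` subdirectories with the
AVI files plus a ``MeasurementsList.csv`` at the top level (the filenames below
are hypothetical)::

    EchoNet-LVH/
        MeasurementsList.csv
        Batch1/
            0X1A2B3C4D5E6F7890.avi
            ...
        Batch2/
            ...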
"""
import csv
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
import jax.numpy as jnp
import numpy as np
from jax import jit, vmap
from tqdm import tqdm
from zea import log
from zea.data import generate_zea_dataset
from zea.data.convert.echonet import H5Processor
from zea.data.convert.echonetlvh.precompute_crop import precompute_cone_parameters
from zea.data.convert.utils import load_avi, unzip
from zea.display import cartesian_to_polar_matrix
from zea.func.tensor import translate
def overwrite_splits(source_dir, rejection_path=None):
"""
    Overwrite splits in MeasurementsList.csv based on manual_rejections.txt or another
    text file listing the hashed filenames to reject.
Args:
source_dir: Source directory containing MeasurementsList.csv and manual_rejections.txt
rejection_path: Path to the rejection txt file. If None, defaults to ./manual_rejections.txt
Returns:
None
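
    Example:
        A rejection file contains one hashed filename per line, matching the
        ``HashedFileName`` column of ``MeasurementsList.csv`` (the hashes below
        are hypothetical)::

            0X1A2B3C4D5E6F7890
            0XABCDEF0123456789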
"""
current_dir = os.path.dirname(os.path.abspath(__file__))
if rejection_path is None:
rejection_path = os.path.join(current_dir, "manual_rejections.txt")
expected_num_rejections = 278
else:
# unknown number of rejections for custom rejection file.
# NOTE: this is used for testing, where we want to use a dummy rejections file
expected_num_rejections = -1
try:
with open(rejection_path) as f:
rejected_hashes = [line.strip() for line in f]
except FileNotFoundError:
log.warning(f"{rejection_path} not found, skipping rejections.")
return
csv_path = Path(source_dir) / "MeasurementsList.csv"
temp_path = Path(source_dir) / "MeasurementsList_temp.csv"
try:
rejection_counter = 0
with (
csv_path.open("r", newline="", encoding="utf-8") as infile,
temp_path.open("w", encoding="utf-8", newline="") as outfile,
):
reader = csv.DictReader(infile)
writer = csv.DictWriter(outfile, fieldnames=reader.fieldnames)
writer.writeheader()
for row in reader:
if row["HashedFileName"] in rejected_hashes:
row["split"] = "rejected"
rejection_counter += 1
writer.writerow(row)
if expected_num_rejections != -1:
assert rejection_counter == expected_num_rejections, (
f"Expected {expected_num_rejections} rejections, but applied only {rejection_counter}."
)
except FileNotFoundError:
log.warning(f"{csv_path} not found, skipping rejections.")
return
temp_path.replace(csv_path)
log.info(f"Overwritten {rejection_counter}/278 rejections to {csv_path}")
return
def load_splits(source_dir):
"""
Load splits from MeasurementsList.csv and return avi filenames
Args:
source_dir: Source directory containing MeasurementsList.csv
Returns:
Dictionary with keys 'train', 'val', 'test', 'rejected' and values as lists of avi filenames
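
    Example:
        A sketch of the returned structure (the filenames are hypothetical)::

            {
                "train": ["0X1A2B3C4D5E6F7890.avi", ...],
                "val": ["0XABCDEF0123456789.avi", ...],
                "test": [...],
                "rejected": [...],
            }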
"""
csv_path = Path(source_dir) / "MeasurementsList.csv"
splits = {"train": [], "val": [], "test": [], "rejected": []}
with open(csv_path, newline="", encoding="utf-8") as csvfile:
reader = csv.DictReader(csvfile)
file_split_map = {}
for row in reader:
filename = row["HashedFileName"]
split = row["split"]
file_split_map.setdefault(filename, split)
for filename, split in file_split_map.items():
splits[split].append(filename + ".avi")
return splits
def find_avi_file(source_dir, hashed_filename, batch=None):
"""
Find AVI file in the specified batch directory or any batch if not specified.
Args:
source_dir: Source directory containing BatchX subdirectories
hashed_filename: Hashed filename (with or without .avi extension)
batch: Specific batch directory to search in (e.g., "Batch2"), or None to search all batches
Returns:
Path to the AVI file if found, else None
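
    Example:
        A minimal sketch (the path and hash are hypothetical)::

            # search every Batch* subdirectory
            avi_path = find_avi_file("/data/EchoNet-LVH", "0X1A2B3C4D5E6F7890")
            # or restrict the search to a single batch directory
            avi_path = find_avi_file("/data/EchoNet-LVH", "0X1A2B3C4D5E6F7890", batch="Batch2")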
"""
# If filename already has .avi extension, strip it
if hashed_filename.endswith(".avi"):
hashed_filename = hashed_filename[:-4]
if batch:
batch_dir = Path(source_dir) / batch
avi_path = batch_dir / f"{hashed_filename}.avi"
if avi_path.exists():
return avi_path
return None
else:
for batch_dir in Path(source_dir).glob("Batch*"):
avi_path = batch_dir / f"{hashed_filename}.avi"
if avi_path.exists():
return avi_path
return None
def load_cone_parameters(csv_path):
"""
Load cone parameters from CSV file into a dictionary.
Args:
csv_path: Path to the CSV file containing cone parameters
Returns:
Dictionary mapping avi_filename to cone parameters
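
    Example:
        A sketch of the returned mapping; only rows with ``status == "success"``
        are kept, and the keys shown are the ones consumed elsewhere in this
        script (the values are hypothetical)::

            {
                "0X1A2B3C4D5E6F7890.avi": {
                    "avi_filename": "0X1A2B3C4D5E6F7890.avi",
                    "status": "success",
                    "apex_above_image": False,
                    "apex_x": 380.0,
                    "opening_angle": 90.0,
                    "crop_left": 100.0,
                    "crop_right": 700.0,
                    "crop_top": -10.0,
                    "crop_bottom": 590.0,
                },
                ...
            }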
"""
cone_params = {}
with open(csv_path, "r", encoding="utf-8") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
if row["status"] == "success":
# Convert string values to appropriate types
params = {}
for key, value in row.items():
if key in ("avi_filename", "status"):
params[key] = value
elif key == "apex_above_image":
params[key] = value.lower() == "true"
elif value is not None and value != "":
params[key] = float(value)
else:
params[key] = None
cone_params[row["avi_filename"]] = params
return cone_params
def crop_frame_with_params(frame, cone_params):
"""
Crop a single frame using predetermined cone parameters.
Args:
frame: Input frame as numpy array
cone_params: Dictionary containing cropping parameters
Returns:
Cropped and padded frame
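
    Example:
        A minimal sketch with a dummy frame (the parameter values are
        hypothetical)::

            import jax.numpy as jnp

            frame = jnp.zeros((600, 800), dtype=jnp.float32)
            params = {
                "crop_left": 100, "crop_right": 700,
                "crop_top": -10, "crop_bottom": 590,
                "apex_x": 380,
            }
            # the negative crop_top adds 10 rows of zero padding at the top, and
            # the frame is padded horizontally so that apex_x ends up centered
            cropped = crop_frame_with_params(frame, params)  # shape (600, 620)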
"""
crop_left = int(cone_params["crop_left"])
crop_right = int(cone_params["crop_right"])
crop_top = int(cone_params["crop_top"])
crop_bottom = int(cone_params["crop_bottom"])
# Handle negative crop_top
if crop_top < 0:
cropped = frame[0:crop_bottom, crop_left:crop_right]
# Add top padding
top_padding = -crop_top
top_pad = jnp.zeros((top_padding, cropped.shape[1]), dtype=cropped.dtype)
cropped = jnp.concatenate([top_pad, cropped], axis=0)
else:
cropped = frame[crop_top:crop_bottom, crop_left:crop_right]
# Apply horizontal centering
apex_x_in_crop = cone_params["apex_x"] - crop_left
cropped_height, cropped_width = cropped.shape
target_center_x = cropped_width / 2
left_padding_needed = target_center_x - apex_x_in_crop
left_padding = max(0, int(left_padding_needed))
right_padding = max(0, int(-left_padding_needed))
if left_padding > 0 or right_padding > 0:
if left_padding > 0:
left_pad = jnp.zeros((cropped_height, left_padding), dtype=cropped.dtype)
cropped = jnp.concatenate([left_pad, cropped], axis=1)
if right_padding > 0:
right_pad = jnp.zeros((cropped_height, right_padding), dtype=cropped.dtype)
cropped = jnp.concatenate([cropped, right_pad], axis=1)
return cropped
def crop_sequence_with_params(sequence, cone_params):
"""
Apply cropping to a sequence of frames using predetermined parameters.
Args:
sequence: Input sequence as numpy array of shape (frames, height, width)
cone_params: Dictionary containing cropping parameters
Returns:
Cropped and padded sequence
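
    Example:
        A minimal sketch, reusing the hypothetical ``params`` dictionary from
        :func:`crop_frame_with_params` above::

            sequence = jnp.zeros((32, 600, 800), dtype=jnp.float32)
            cropped = crop_sequence_with_params(sequence, params)  # (32, 600, 620)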
"""
crop_sequence = vmap(lambda frame: crop_frame_with_params(frame, cone_params))
return crop_sequence(sequence)
class LVHProcessor(H5Processor):
"""Modified H5Processor for EchoNet-LVH dataset."""
def __init__(self, *args, cone_params=None, **kwargs):
super().__init__(*args, **kwargs)
        # JIT-compile the scan conversion and vectorize it over the frame axis;
        # the opening angle is constant across a sequence, so it is not mapped over.
        self.cart2pol_jit = jit(cartesian_to_polar_matrix)
        self.cart2pol_batched = vmap(
            (lambda matrix, angle: self.cart2pol_jit(matrix, angle=angle)), in_axes=(0, None)
        )
        # Store the pre-computed cone parameters
        self.cone_parameters = cone_params or {}
def get_split(self, avi_file: str, sequence):
"""
Get the split (train/val/test) for a given AVI file.
Args:
avi_file: Path to the AVI file
sequence: Video sequence (unused)
Returns:
String indicating the split ('train', 'val', or 'test')
"""
        # Reconstruct the AVI filename from the file stem
        filename = Path(avi_file).stem + ".avi"
for split, files in self.splits.items():
if filename in files:
return split
raise UserWarning("Unknown split for file: " + filename)
def __call__(self, avi_file):
"""Takes a single avi_file and generates a zea dataset
Args:
avi_file: String or path to avi_file to be processed
Returns:
zea dataset
"""
avi_filename = Path(avi_file).stem + ".avi"
sequence_np = load_avi(avi_file)
sequence_processed = jnp.array(sequence_np)
sequence_processed = translate(sequence_processed, self.range_from, self._process_range)
# Get pre-computed cone parameters for this file
cone_params = self.cone_parameters.get(avi_filename)
if cone_params is not None:
# Apply pre-computed cropping parameters
sequence_processed = crop_sequence_with_params(sequence_processed, cone_params)
else:
raise UserWarning(f"No cone parameters for {avi_filename}")
split = self.get_split(avi_file, sequence_processed)
out_h5 = self.path_out_h5 / split / (Path(avi_file).stem + ".hdf5")
angle = cone_params["opening_angle"] / 2 # angular field spans (-angle, +angle)
polar_im_set = self.cart2pol_batched(sequence_processed, angle)
sequence_processed = translate(sequence_processed, self._process_range, self.range_from)
sequence_processed_uint8 = jnp.asarray(jnp.floor(sequence_processed + 0.5), dtype=jnp.uint8)
del sequence_processed
polar_im_set = translate(polar_im_set, self._process_range, (0, 255))
polar_im_set_uint8 = jnp.asarray(jnp.floor(polar_im_set + 0.5), dtype=jnp.uint8)
del polar_im_set
if jnp.all(sequence_processed_uint8 == 0):
raise ValueError(f"Processed sequence is all zeros for file {avi_file}")
if jnp.all(polar_im_set_uint8 == 0):
raise ValueError(f"Polar sequence is all zeros for file {avi_file}")
zea_dataset = {
"path": out_h5,
"image_sc": sequence_processed_uint8,
"probe_name": "generic",
"description": "EchoNet-LVH dataset converted to zea format",
"image": polar_im_set_uint8,
"cast_to_float": False,
}
return generate_zea_dataset(**zea_dataset)
def convert_measurements_csv(source_csv, output_csv, cone_params_csv=None):
"""Convert measurements CSV file with updated coordinates using cone parameters.
Args:
source_csv: Path to source CSV file
output_csv: Path to output CSV file
cone_params_csv: Path to CSV file with cone parameters
"""
try:
# Read the CSV file
with open(source_csv, newline="", encoding="utf-8") as csvfile:
reader = csv.DictReader(csvfile)
rows = list(reader)
fieldnames = reader.fieldnames
# Load cone parameters if available
cone_parameters = {}
if cone_params_csv and Path(cone_params_csv).exists():
cone_parameters = load_cone_parameters(cone_params_csv)
else:
log.warning("No cone parameters file found. Measurements will not be transformed.")
# Apply coordinate transformation and track skipped rows
transformed_rows = []
skipped_files = set()
for row in rows:
try:
avi_filename = row["HashedFileName"] + ".avi"
cone_params = cone_parameters.get(avi_filename, None)
transformed_row = transform_measurement_coordinates_with_cone_params(
row, cone_params
)
if transformed_row is not None:
transformed_rows.append(transformed_row)
else:
skipped_files.add(row["HashedFileName"])
except Exception as e:
log.error(f"Error processing row for file {row['HashedFileName']}: {str(e)}")
skipped_files.add(row["HashedFileName"])
# Save to new CSV file
if transformed_rows:
# Use keys from first row as fieldnames
out_fieldnames = list(transformed_rows[0].keys())
with open(output_csv, "w", newline="", encoding="utf-8") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=out_fieldnames)
writer.writeheader()
writer.writerows(transformed_rows)
else:
# Write header only if no rows
with open(output_csv, "w", newline="", encoding="utf-8") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
# Print summary
log.info("Conversion Summary:")
log.info(f"Total rows processed: {len(rows)}")
log.info(f"Rows successfully converted: {len(transformed_rows)}")
log.info(f"Rows skipped: {len(rows) - len(transformed_rows)}")
if skipped_files:
log.info("Skipped files:")
for filename in sorted(skipped_files):
log.info(f" - {filename}")
log.info(f"Converted measurements saved to {output_csv}")
except Exception as e:
log.error(f"Error processing CSV file: {str(e)}")
raise
def _process_file_worker(avi_file, dst, splits, cone_parameters, range_from, process_range):
"""
    Worker function that processes a single file in a separate process.
Args:
avi_file: Path to the AVI file to process
dst: Destination directory for output
splits: Dictionary of splits
cone_parameters: Dictionary of cone parameters
        range_from: Source intensity range of the input frames
        process_range: Intensity range used internally during processing
Returns:
Result of processing the file
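
    Example:
        A sketch of how this worker could be fanned out over multiple processes
        with the ``ProcessPoolExecutor`` imported above (the current conversion
        loop in :func:`convert_echonetlvh` is serial; the variable names are
        illustrative)::

            with ProcessPoolExecutor(max_workers=4) as pool:
                futures = [
                    pool.submit(
                        _process_file_worker,
                        avi_file, dst, splits, cone_parameters,
                        range_from, process_range,
                    )
                    for avi_file in files_to_process
                ]
                for future in as_completed(futures):
                    future.result()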
"""
# create a fresh processor inside the worker process
proc = LVHProcessor(path_out_h5=dst, splits=splits, cone_params=cone_parameters)
    # propagate the intensity-range settings to the fresh processor
proc.range_from = range_from
proc._process_range = process_range
return proc(avi_file)
def convert_echonetlvh(args):
"""
    Conversion entry point for the EchoNet-LVH dataset.
    Unzips the archive, overwrites splits if manual rejections are provided,
    precomputes cone parameters if needed, converts images and/or measurements
    to zea format, and saves the resulting dataset.
    Called with argparse arguments through ``zea/zea/data/convert/__main__.py``.
Args:
args (argparse.Namespace): Command-line arguments
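
    Example:
        A minimal sketch of a direct call; the attribute names mirror the ones
        accessed in this function, but the real CLI flags may differ::

            from argparse import Namespace

            args = Namespace(
                src="/data/EchoNet-LVH",
                dst="/data/echonetlvh_zea",
                batch=None,
                max_files=None,
                no_rejection=False,
                convert_images=True,
                convert_measurements=True,
            )
            convert_echonetlvh(args)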
"""
# Check if unzip is needed
src = unzip(args.src, "echonetlvh")
# Overwrite the splits if manual rejections are provided
if not args.no_rejection:
overwrite_splits(args.src, getattr(args, "rejection_path", None))
# Check that cone parameters exist
cone_params_csv = Path(args.dst) / "cone_parameters.csv"
if not cone_params_csv.exists():
precompute_cone_parameters(args)
# If no specific conversion is requested, convert both
if not (args.convert_measurements or args.convert_images):
args.convert_measurements = True
args.convert_images = True
# Convert images if requested
if args.convert_images:
source_path = Path(src)
splits = load_splits(source_path)
# Load precomputed cone parameters
cone_parameters = load_cone_parameters(cone_params_csv)
log.info(f"Loaded cone parameters for {len(cone_parameters)} files")
files_to_process = []
for split_files in splits.values():
for avi_filename in split_files:
# Strip .avi if present
base_filename = avi_filename[:-4] if avi_filename.endswith(".avi") else avi_filename
avi_file = find_avi_file(src, base_filename, batch=args.batch)
if avi_file:
files_to_process.append(avi_file)
else:
                    log.warning(
                        f"Could not find AVI file for {base_filename} in batch "
                        f"{args.batch if args.batch else 'any'}"
                    )
# List files that have already been processed
files_done = []
for _, _, filenames in os.walk(args.dst):
for filename in filenames:
if filename.endswith(".hdf5"):
files_done.append(filename.replace(".hdf5", ""))
# Filter out already processed files
files_to_process = [f for f in files_to_process if f.stem not in files_done]
# Limit files if max_files is specified
if args.max_files is not None:
files_to_process = files_to_process[: args.max_files]
log.info(f"Limited to processing {args.max_files} files due to max_files parameter")
log.info(f"Files left to process: {len(files_to_process)}")
# Initialize processor with splits and cone parameters
processor = LVHProcessor(path_out_h5=args.dst, splits=splits, cone_params=cone_parameters)
log.info("Starting the conversion process.")
for file in tqdm(files_to_process):
try:
processor(file)
except Exception as e:
log.error(f"Error processing {file}: {str(e)}")
log.info("All image conversion tasks are completed.")
# Convert measurements if requested
if args.convert_measurements:
source_path = Path(src)
measurements_csv = source_path / "MeasurementsList.csv"
if measurements_csv.exists():
output_csv = Path(args.dst) / "MeasurementsList.csv"
convert_measurements_csv(measurements_csv, output_csv, cone_params_csv)
else:
log.warning("MeasurementsList.csv not found in source directory")
log.info("All tasks are completed.")