"""
Functions to write and validate datasets in the zea format.
"""
import inspect
from dataclasses import dataclass
from pathlib import Path
import h5py
import numpy as np
from keras.utils import pad_sequences
from zea import log
from zea.data.file import File, validate_file
from zea.internal.checks import _DATA_TYPES
from zea.internal.utils import first_not_none_item
[docs]
@dataclass
class DatasetElement:
    """Container for a user-defined (non-standard) dataset entry in a zea file.

    Passed via the ``additional_elements`` argument of ``generate_zea_dataset``;
    each element is written as an HDF5 dataset (with ``description`` and ``unit``
    attributes) under the ``non_standard_elements`` group, optionally nested
    inside ``group_name``.
    """

    # The name of the dataset. This will be the key in the group.
    dataset_name: str
    # The data to store in the dataset.
    data: np.ndarray
    # Human-readable description, stored as an HDF5 attribute on the dataset.
    description: str
    # Physical unit of the data (e.g. "m", "s", "unitless"), stored as an attribute.
    unit: str
    # The group name to store the dataset under. This can be a nested group, e.g.
    # "lens/profiles". An empty string stores it directly under
    # "non_standard_elements".
    group_name: str = ""
[docs]
def generate_example_dataset(
    path,
    add_optional_fields=False,
    add_optional_dtypes=False,
    n_frames=2,
    n_ax=2048,
    n_el=128,
    n_tx=11,
    n_ch=1,
    sound_speed=1540,
    center_frequency=7e6,
    sampling_frequency=40e6,
    grid_size_z=512,
    grid_size_x=512,
):
    """Generates an example dataset that contains all the necessary fields.

    Note: This dataset does not contain actual data, but is filled with
    constant placeholder values.

    Args:
        path (str): The path to write the dataset to.
        add_optional_fields (bool, optional): Whether to add optional fields
            (focus_distances, tx_apodizations, polar_angles, azimuth_angles)
            to the dataset. Defaults to False.
        add_optional_dtypes (bool, optional): Whether to add optional data types
            (aligned_data, envelope_data, beamformed_data, image_sc) to the
            dataset. Defaults to False.
        n_frames (int, optional): Number of frames. Defaults to 2.
        n_ax (int, optional): Number of axial samples. Defaults to 2048.
        n_el (int, optional): Number of probe elements. Defaults to 128.
        n_tx (int, optional): Number of transmits per frame. Defaults to 11.
        n_ch (int, optional): Number of channels. Defaults to 1.
        sound_speed (float, optional): Speed of sound in m/s. Defaults to 1540.
        center_frequency (float, optional): Center frequency in Hz. Defaults to 7e6.
        sampling_frequency (float, optional): Sampling frequency in Hz.
            Defaults to 40e6.
        grid_size_z (int, optional): Image grid size along z. Defaults to 512.
        grid_size_x (int, optional): Image grid size along x. Defaults to 512.
    """
    # creating some fake raw and image data
    raw_data = np.ones((n_frames, n_tx, n_ax, n_el, n_ch))
    # image data is in dB
    image = np.ones((n_frames, grid_size_z, grid_size_x)) * -40
    # creating some fake scan parameters for a generic linear array
    # spanning [-0.02, 0.02] m along x
    t0_delays = np.zeros((n_tx, n_el), dtype=np.float32)
    probe_geometry = np.zeros((n_el, 3), dtype=np.float32)
    probe_geometry[:, 0] = np.linspace(-0.02, 0.02, n_el)
    initial_times = np.zeros((n_tx,), dtype=np.float32)
    if add_optional_fields:
        # np.inf focus distance denotes plane-wave transmits.
        focus_distances = np.ones((n_tx,), dtype=np.float32) * np.inf
        tx_apodizations = np.zeros((n_tx, n_el), dtype=np.float32)
        polar_angles = np.zeros((n_tx,), dtype=np.float32)
        azimuth_angles = np.zeros((n_tx,), dtype=np.float32)
    else:
        focus_distances = None
        tx_apodizations = None
        polar_angles = None
        azimuth_angles = None
    if add_optional_dtypes:
        aligned_data = np.ones((n_frames, n_tx, n_ax, n_el, n_ch))
        envelope_data = np.ones((n_frames, grid_size_z, grid_size_x))
        beamformed_data = np.ones((n_frames, grid_size_z, grid_size_x, n_ch))
        image_sc = np.ones_like(image)
    else:
        aligned_data = None
        envelope_data = None
        beamformed_data = None
        image_sc = None
    generate_zea_dataset(
        path,
        raw_data=raw_data,
        aligned_data=aligned_data,
        envelope_data=envelope_data,
        beamformed_data=beamformed_data,
        image=image,
        image_sc=image_sc,
        probe_geometry=probe_geometry,
        sampling_frequency=sampling_frequency,
        center_frequency=center_frequency,
        initial_times=initial_times,
        t0_delays=t0_delays,
        sound_speed=sound_speed,
        tx_apodizations=tx_apodizations,
        probe_name="generic",
        focus_distances=focus_distances,
        polar_angles=polar_angles,
        azimuth_angles=azimuth_angles,
        additional_elements=_generate_example_dataset_elements(),
        description="This is an example dataset generated by zea",
    )
def _generate_example_dataset_elements() -> list[DatasetElement]:
    """Builds the example DatasetElement objects used by generate_example_dataset.

    Returns:
        list: Three example elements — a scalar, an array dataset in a nested
        group, and a string dataset.
    """
    scalar_element = DatasetElement(
        dataset_name="temperature",
        data=np.array(42),
        description="The temperature during the measurement",
        unit="unitless",
    )
    profile_element = DatasetElement(
        dataset_name="lens_profile",
        data=np.random.rand(100),
        description="An example lens profile",
        unit="mm",
        group_name="lens",
    )
    material_element = DatasetElement(
        dataset_name="lens_material",
        data=np.array(["material1", "material2", "material3"], dtype=h5py.string_dtype()),
        description="An example lens material list",
        unit="unitless",
        group_name="lens",
    )
    return [scalar_element, profile_element, material_element]
# specific checks for each data type are done in validate_file
def _write_datasets(
    dataset,
    data_group_name="data",
    scan_group_name="scan",
    raw_data=None,
    aligned_data=None,
    envelope_data=None,
    beamformed_data=None,
    image=None,
    image_sc=None,
    n_ax=None,
    n_el=None,
    n_tx=None,
    n_ch=None,
    n_frames=None,
    sound_speed=None,
    probe_geometry=None,
    sampling_frequency=None,
    center_frequency=None,
    demodulation_frequency=None,
    initial_times=None,
    t0_delays=None,
    tx_apodizations=None,
    focus_distances=None,
    transmit_origins=None,
    polar_angles=None,
    azimuth_angles=None,
    bandwidth_percent=None,
    time_to_next_transmit=None,
    tgc_gain_curve=None,
    element_width=None,
    tx_waveform_indices=None,
    waveforms_one_way=None,
    waveforms_two_way=None,
    additional_elements=None,
    cast_to_float=True,
    enable_compression=True,
    **kwargs,
):
    """Writes the data and scan-parameter groups of a zea dataset.

    Creates the ``data_group_name`` and ``scan_group_name`` groups in
    ``dataset`` and writes every non-None argument as an HDF5 dataset with
    ``description`` and ``unit`` attributes. Scan dimensions (n_frames, n_tx,
    n_ax, n_el, n_ch) that are not given explicitly are inferred from the
    shapes of the supplied arrays where possible. See ``generate_zea_dataset``
    for the meaning and expected shape of the individual fields.

    Args:
        dataset (h5py.File or h5py.Group): Open, writable HDF5 file or group.
        data_group_name (str): Name of the group that holds the data arrays.
        scan_group_name (str): Name of the group that holds the scan parameters.
        additional_elements (list, optional): DatasetElement objects written
            under the "non_standard_elements" group.
        cast_to_float (bool): Whether to cast array data to float32 before
            writing.
        enable_compression (bool): Whether to gzip-compress non-scalar datasets.

    Raises:
        ValueError: If unknown keyword arguments are supplied.
    """
    # Reject typos / unsupported fields instead of silently dropping them.
    if kwargs:
        raise ValueError(f"Unknown arguments: {list(kwargs.keys())}")
    def _convert_datatype(x, astype=np.float32):
        # Casts x to float32 (None passes through); no-op when cast_to_float
        # is False (e.g. to keep integer image data as-is).
        if cast_to_float:
            return x.astype(astype) if x is not None else None
        else:
            return x
    def _first_not_none_shape(arr, axis):
        # Returns shape[axis] of the first non-None entry of arr, or None if
        # every entry is None.
        data = first_not_none_item(arr)
        return data.shape[axis] if data is not None else None
    def _add_dataset(group_name: str, name: str, data: np.ndarray, description: str, unit: str):
        """Adds a dataset to the given group with a description and unit.
        If data is None, the dataset is not added."""
        if data is None:
            return
        data = np.asarray(data)
        # Create the group if it does not exist
        if group_name not in dataset:
            group = dataset.create_group(group_name)
        else:
            group = dataset[group_name]
        # 0-d (scalar) datasets are written without compression.
        dataset_is_scalar = np.isscalar(data) or data.ndim == 0
        compression = "gzip" if enable_compression and not dataset_is_scalar else None
        new_dataset = group.create_dataset(name, data=data, compression=compression)
        new_dataset.attrs["description"] = description
        new_dataset.attrs["unit"] = unit
    # Write data group
    data_group = dataset.create_group(data_group_name)
    data_group.attrs["description"] = "This group contains the data."
    # Infer missing scan dimensions from array shapes (first non-None wins).
    # NOTE(review): if every data argument is None and n_frames is not given,
    # the .shape access below raises AttributeError — confirm callers always
    # supply at least one data array or n_frames.
    if n_frames is None:
        n_frames = first_not_none_item(
            [raw_data, aligned_data, envelope_data, beamformed_data, image, image_sc]
        ).shape[0]
    if n_tx is None:
        n_tx = _first_not_none_shape([raw_data, aligned_data], axis=1)
    if n_ax is None:
        # NOTE(review): beamformed_data is documented as
        # (n_frames, grid_size_z, grid_size_x, n_ch); axis=-3 reads
        # grid_size_z there, not an axial sample count — confirm intended.
        n_ax = _first_not_none_shape([raw_data, aligned_data, beamformed_data], axis=-3)
    if n_el is None:
        n_el = _first_not_none_shape([raw_data], axis=-2)
    if n_ch is None:
        n_ch = _first_not_none_shape([raw_data, aligned_data, beamformed_data], axis=-1)
    # Fall back to scan-parameter shapes when no data arrays were supplied.
    if n_tx is None:
        n_tx = _first_not_none_shape(
            [t0_delays, focus_distances, polar_angles, transmit_origins], axis=0
        )
    if n_el is None:
        n_el = _first_not_none_shape([t0_delays], axis=1)
    if n_el is None:
        n_el = _first_not_none_shape([probe_geometry], axis=0)
    # Last resort: assume a single transmit per frame.
    if n_tx is None:
        n_tx = 1
    _add_dataset(
        group_name=data_group_name,
        name="raw_data",
        data=_convert_datatype(raw_data),
        description="The raw_data of shape (n_frames, n_tx, n_ax, n_el, n_ch).",
        unit="unitless",
    )
    _add_dataset(
        group_name=data_group_name,
        name="aligned_data",
        data=_convert_datatype(aligned_data),
        description="The aligned_data of shape (n_frames, n_tx, n_ax, n_el, n_ch).",
        unit="unitless",
    )
    _add_dataset(
        group_name=data_group_name,
        name="envelope_data",
        data=_convert_datatype(envelope_data),
        description="The envelope_data of shape (n_frames, grid_size_z, grid_size_x).",
        unit="unitless",
    )
    _add_dataset(
        group_name=data_group_name,
        name="beamformed_data",
        data=_convert_datatype(beamformed_data),
        description="The beamformed_data of shape (n_frames, grid_size_z, grid_size_x).",
        unit="unitless",
    )
    _add_dataset(
        group_name=data_group_name,
        name="image",
        data=_convert_datatype(image),
        unit="unitless",
        description="The images of shape [n_frames, grid_size_z, grid_size_x]",
    )
    _add_dataset(
        group_name=data_group_name,
        name="image_sc",
        data=_convert_datatype(image_sc),
        unit="unitless",
        description=("The scan converted images of shape [n_frames, output_size_z, output_size_x]"),
    )
    # Write scan group
    scan_group = dataset.create_group(scan_group_name)
    scan_group.attrs["description"] = "This group contains the scan parameters."
    _add_dataset(
        group_name=scan_group_name,
        name="n_ax",
        data=n_ax,
        description="The number of axial samples.",
        unit="unitless",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="n_el",
        data=n_el,
        description="The number of elements in the probe.",
        unit="unitless",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="n_tx",
        data=n_tx,
        description="The number of transmits per frame.",
        unit="unitless",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="n_ch",
        data=n_ch,
        description=("The number of channels. For RF data this is 1. For IQ data this is 2."),
        unit="unitless",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="n_frames",
        data=n_frames,
        description="The number of frames.",
        unit="unitless",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="sound_speed",
        data=sound_speed,
        description="The speed of sound in m/s",
        unit="m/s",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="probe_geometry",
        data=probe_geometry,
        description="The probe geometry of shape (n_el, 3).",
        unit="m",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="sampling_frequency",
        data=sampling_frequency,
        description="The sampling frequency in Hz.",
        unit="Hz",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="center_frequency",
        data=center_frequency,
        description="The center frequency of the transmit pulse in Hz.",
        unit="Hz",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="demodulation_frequency",
        data=demodulation_frequency,
        description="The frequency at which the data should be "
        "demodulated in Hz. (Usually the same as center_frequency, "
        "but different when doing harmonic imaging.)",
        unit="Hz",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="initial_times",
        data=initial_times,
        description=(
            "The times when the A/D converter starts sampling "
            "in seconds of shape (n_tx,). This is the time between the "
            "first element firing and the first recorded sample."
        ),
        unit="s",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="t0_delays",
        data=t0_delays,
        description="The t0_delays of shape (n_tx, n_el).",
        unit="s",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="tx_apodizations",
        data=tx_apodizations,
        description=(
            "The transmit delays for each element defining the"
            " wavefront in seconds of shape (n_tx, n_elem). This is"
            " the time at which each element fires shifted such that"
            " the first element fires at t=0."
        ),
        unit="unitless",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="focus_distances",
        data=focus_distances,
        description=(
            "The transmit focus distances in meters of "
            "shape (n_tx,). This is the distance from the origin point on the transducer to where "
            "the beam comes to focus. For planewaves this is set to Inf."
        ),
        unit="m",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="transmit_origins",
        data=transmit_origins,
        description=(
            "The transmit origins in meters of the transmit beams "
            "of shape (n_tx, 3). This is the (x, y, z) position "
            "from which the beam is transmitted."
        ),
        unit="m",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="polar_angles",
        data=polar_angles,
        description=("The polar angles of the transmit beams in radians of shape (n_tx,)."),
        unit="rad",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="azimuth_angles",
        data=azimuth_angles,
        description=("The azimuthal angles of the transmit beams in radians of shape (n_tx,)."),
        unit="rad",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="bandwidth_percent",
        data=bandwidth_percent,
        description=("The receive bandwidth of RF signal in percentage of center frequency."),
        unit="unitless",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="time_to_next_transmit",
        data=time_to_next_transmit,
        description=("The time between subsequent transmit events of shape (n_frames, n_tx)."),
        unit="s",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="tgc_gain_curve",
        data=tgc_gain_curve,
        description=(
            "The time-gain-compensation that was applied to every sample in the "
            "raw_data of shape (n_ax,). Divide by this curve to undo the TGC."
        ),
        unit="unitless",
    )
    _add_dataset(
        group_name=scan_group_name,
        name="element_width",
        data=element_width,
        description="The width of the elements in the probe in meters.",
        unit="m",
    )
    # Waveform indices are only meaningful when at least one waveform list is
    # also supplied.
    if tx_waveform_indices is not None and (
        waveforms_one_way is not None or waveforms_two_way is not None
    ):
        _add_dataset(
            group_name=scan_group_name,
            name="tx_waveform_indices",
            data=tx_waveform_indices,
            description=(
                "Transmit indices for waveforms, indexing waveforms_one_way "
                "and waveforms_two_way. This indicates which transmit waveform was "
                "used for each transmit event."
            ),
            unit="-",
        )
    # Waveforms are variable-length 1D arrays; pad them to a rectangular
    # array so they can be stored as a single HDF5 dataset.
    if waveforms_one_way is not None:
        _add_dataset(
            group_name=scan_group_name,
            name="waveforms_one_way",
            data=pad_sequences(waveforms_one_way, dtype=np.float32, padding="post"),
            description=(
                "One-way waveform as simulated by the Verasonics system, "
                "sampled at 250MHz. This is the waveform after being filtered "
                "by the transducer bandwidth once."
            ),
            unit="V",
        )
    if waveforms_two_way is not None:
        _add_dataset(
            group_name=scan_group_name,
            name="waveforms_two_way",
            data=pad_sequences(waveforms_two_way, dtype=np.float32, padding="post"),
            description=(
                "Two-way waveform as simulated by the Verasonics system, "
                "sampled at 250MHz. This is the waveform after being filtered "
                "by the transducer bandwidth twice."
            ),
            unit="V",
        )
    # Add additional elements
    if additional_elements is not None:
        # Write scan group
        non_standard_elements_group_name = "non_standard_elements"
        non_standard_elements_group = dataset.create_group(non_standard_elements_group_name)
        non_standard_elements_group.attrs["description"] = (
            "This group contains non-standard elements that can be added by the user."
        )
        for element in additional_elements:
            group_name = non_standard_elements_group_name
            if element.group_name != "":
                group_name += f"/{element.group_name}"
            _add_dataset(
                group_name=group_name,
                name=element.dataset_name,
                data=element.data,
                description=element.description,
                unit=element.unit,
            )
[docs]
def generate_zea_dataset(
    path,
    raw_data=None,
    aligned_data=None,
    envelope_data=None,
    beamformed_data=None,
    image=None,
    image_sc=None,
    probe_geometry=None,
    sampling_frequency=None,
    center_frequency=None,
    demodulation_frequency=None,
    initial_times=None,
    t0_delays=None,
    sound_speed=None,
    probe_name=None,
    description="No description was supplied",
    focus_distances=None,
    transmit_origins=None,
    polar_angles=None,
    azimuth_angles=None,
    tx_apodizations=None,
    bandwidth_percent=None,
    time_to_next_transmit=None,
    tgc_gain_curve=None,
    element_width=None,
    tx_waveform_indices=None,
    waveforms_one_way=None,
    waveforms_two_way=None,
    additional_elements=None,
    event_structure=False,
    cast_to_float=True,
    overwrite=False,
    enable_compression=True,
):
    """Generates a dataset in the zea format.
    Args:
        path (str): The path to write the dataset to.
        raw_data (np.ndarray): The raw data of the ultrasound measurement of
            shape (n_frames, n_tx, n_ax, n_el, n_ch).
        aligned_data (np.ndarray): The aligned data of the ultrasound measurement of
            shape (n_frames, n_tx, n_ax, n_el, n_ch).
        envelope_data (np.ndarray): The envelope data of the ultrasound measurement of
            shape (n_frames, grid_size_z, grid_size_x).
        beamformed_data (np.ndarray): The beamformed data of the ultrasound measurement of
            shape (n_frames, grid_size_z, grid_size_x, n_ch).
        image (np.ndarray): The ultrasound images to be saved
            of shape (n_frames, grid_size_z, grid_size_x).
        image_sc (np.ndarray): The scan converted ultrasound images to be saved
            of shape (n_frames, output_size_z, output_size_x).
        probe_geometry (np.ndarray): The probe geometry of shape (n_el, 3).
        sampling_frequency (float): The sampling frequency in Hz.
        center_frequency (float): The center frequency of the transmit pulse in Hz.
        demodulation_frequency (float): The demodulation frequency in Hz.
        initial_times (list): The times when the A/D converter starts sampling
            in seconds of shape (n_tx,). This is the time between the first element
            firing and the first recorded sample.
        t0_delays (np.ndarray): The t0_delays of shape (n_tx, n_el).
        sound_speed (float): The speed of sound in m/s.
        probe_name (str): The name of the probe.
        description (str): The description of the dataset.
        focus_distances (np.ndarray): The focus distances of shape (n_tx,).
        transmit_origins (np.ndarray): The transmit origins of shape (n_tx, 3).
        polar_angles (np.ndarray): The polar angles (radians) of shape (n_tx,).
        azimuth_angles (np.ndarray): The azimuth angles (radians) of shape (n_tx,).
        tx_apodizations (np.ndarray): The transmit delays for each element defining
            the wavefront in seconds of shape (n_tx, n_elem).
            This is the time between the first element firing and the last element firing.
        bandwidth_percent (float): The bandwidth of the transducer as a
            percentage of the center frequency.
        time_to_next_transmit (np.ndarray): The time between subsequent transmit events in s
            of shape (n_frames, n_tx).
        tgc_gain_curve (np.ndarray): The TGC gain that was applied to every sample in the
            raw_data of shape (n_ax).
        element_width (float): The width of the elements in the probe in meters of
            shape (n_tx,).
        tx_waveform_indices (np.ndarray): Transmit indices for waveforms, indexing
            waveforms_one_way and waveforms_two_way. This indicates which transmit
            waveform was used for each transmit event.
        waveforms_one_way (list): List of one-way waveforms as simulated by the Verasonics
            system, sampled at 250MHz. This is the waveform after being filtered by the
            transducer bandwidth once. Every element in the list is a 1D numpy array.
        waveforms_two_way (list): List of two-way waveforms as simulated by the Verasonics
            system, sampled at 250MHz. This is the waveform after being filtered by the
            transducer bandwidth twice. Every element in the list is a 1D numpy array.
        additional_elements (List[DatasetElement]): A list of additional dataset
            elements to be added to the dataset. Each element should be a DatasetElement
            object. The additional elements are added under the scan group.
        event_structure (bool): Whether to write the dataset with an event structure.
            In that case all data should be lists with the same length (number of events).
            The data will be stored under event_i/data and event_i/scan for each event i.
            Instead of just a single data and scan group.
        cast_to_float (bool): Whether to store data as float32. You may want to set this
            to False if storing images.
        overwrite (bool): Whether to overwrite the file if it already exists. Defaults to False.
        enable_compression (bool): Whether to enable gzip compression for datasets.
            Defaults to True. Compression reduces disk space at the cost of increased
            write time.

    Raises:
        FileExistsError: If the file already exists and ``overwrite`` is False.
    """
    # check if all args are lists
    if isinstance(probe_name, list):
        # all names in probe_name list should be the same
        assert len(set(probe_name)) == 1, "Probe names for all events should be the same"
    data_and_parameters = {
        "raw_data": raw_data,
        "aligned_data": aligned_data,
        "envelope_data": envelope_data,
        "beamformed_data": beamformed_data,
        "image": image,
        "image_sc": image_sc,
        "probe_geometry": probe_geometry,
        "sampling_frequency": sampling_frequency,
        "center_frequency": center_frequency,
        "demodulation_frequency": demodulation_frequency,
        "initial_times": initial_times,
        "t0_delays": t0_delays,
        "sound_speed": sound_speed,
        "probe_name": probe_name,
        "description": description,
        "focus_distances": focus_distances,
        "transmit_origins": transmit_origins,
        "polar_angles": polar_angles,
        "azimuth_angles": azimuth_angles,
        "tx_apodizations": tx_apodizations,
        "bandwidth_percent": bandwidth_percent,
        "time_to_next_transmit": time_to_next_transmit,
        "tgc_gain_curve": tgc_gain_curve,
        "element_width": element_width,
        "tx_waveform_indices": tx_waveform_indices,
        "waveforms_one_way": waveforms_one_way,
        "waveforms_two_way": waveforms_two_way,
        "additional_elements": additional_elements,
    }
    # make sure input arguments of func is same length as data_and_parameters
    # except `path` and `event_structure` arguments and ofcourse `data_and_parameters` itself
    assert (
        len(data_and_parameters) == len(inspect.signature(generate_zea_dataset).parameters) - 5
    ), (
        "All arguments should be put in data_and_parameters except "
        "`path`, `event_structure`, `cast_to_float`, `overwrite`, and `enable_compression` "
        "arguments."
    )
    if event_structure:
        # Every supplied argument must be a list with one entry per event, and
        # all lists must agree on the number of events.
        # Fix: _num_events must be initialized once BEFORE the loop; it was
        # previously reset on every iteration, which disabled the
        # cross-argument length check entirely.
        _num_events = None
        for argument, argument_value in data_and_parameters.items():
            if argument_value is not None:
                assert isinstance(argument_value, list), (
                    f"{argument} should be a list when event_structure is set to True."
                )
                num_events = len(argument_value)
                if _num_events is not None:
                    assert num_events == _num_events, (
                        "All arguments should have the same number of events."
                    )
                _num_events = num_events
        assert len(set(probe_name)) == 1, "Probe names for all events should be the same"
        log.info(
            f"Event structure is set to True. Writing dataset with event "
            f"structure (found {len(probe_name)} events)."
        )
        num_events = len(probe_name)
        # probe_name/description are per-event lists with identical entries;
        # collapse them to scalars for the file-level attributes.
        probe_name = probe_name[0]
        description = description[0]
    assert isinstance(probe_name, str), "The probe name must be a string."
    assert isinstance(description, str), "The description must be a string."
    assert isinstance(event_structure, bool), "The event_structure must be a boolean."
    validate_input_data(
        raw_data=raw_data,
        aligned_data=aligned_data,
        envelope_data=envelope_data,
        beamformed_data=beamformed_data,
        image=image,
        image_sc=image_sc,
    )
    # Convert path to Path object
    path = Path(path)
    if path.exists() and not overwrite:
        raise FileExistsError(f"The file {path} already exists.")
    # Create the directory if it does not exist
    path.parent.mkdir(parents=True, exist_ok=True)
    with File(path, "w") as dataset:
        dataset.attrs["probe"] = probe_name
        dataset.attrs["description"] = description
        dataset.attrs["event_structure"] = event_structure
        # remove probe and description from data_and_parameters
        # (they are stored as file-level attributes, not datasets)
        data_and_parameters.pop("probe_name")
        data_and_parameters.pop("description")
        if event_structure:
            for i in range(num_events):
                # Select the i-th event's entry from every per-event list.
                _data_and_parameters = {
                    k: v[i] for k, v in data_and_parameters.items() if v is not None
                }
                _write_datasets(
                    dataset,
                    data_group_name=f"event_{i}/data",
                    scan_group_name=f"event_{i}/scan",
                    cast_to_float=cast_to_float,
                    enable_compression=enable_compression,
                    **_data_and_parameters,
                )
        else:
            _write_datasets(
                dataset,
                data_group_name="data",
                scan_group_name="scan",
                cast_to_float=cast_to_float,
                enable_compression=enable_compression,
                **data_and_parameters,
            )
    validate_file(path)
    log.info(f"zea dataset written to {log.yellow(path)}")
[docs]
def load_description(path):
    """Loads the description of a zea dataset.

    Args:
        path (str): The path to the zea dataset.

    Returns:
        str: The description of the dataset, or an empty string if not found.
    """
    with File(Path(path), "r") as file:
        return file.attrs.get("description", "")
[docs]
def load_additional_elements(path):
    """Loads additional dataset elements from a zea dataset.

    Args:
        path (str): The path to the zea dataset.

    Returns:
        list: A list of DatasetElement objects; empty if the file contains no
        "non_standard_elements" group.
    """
    group_name = "non_standard_elements"
    with File(Path(path), "r") as file:
        if group_name in file:
            return _load_additional_elements_from_group(file, group_name)
        return []
def _load_additional_elements_from_group(file, path):
    """Recursively collects DatasetElement objects found below ``path``."""
    collected = []
    for child_name in file[path]:
        child_path = f"{path}/{child_name}"
        child = file[child_path]
        if isinstance(child, h5py.Dataset):
            # Leaf node: load it as a DatasetElement.
            collected.append(_load_dataset_element_from_group(file, child_path))
        elif isinstance(child, h5py.Group):
            # Subgroup: recurse and merge its elements.
            collected += _load_additional_elements_from_group(file, child_path)
    return collected
def _load_dataset_element_from_group(file, path):
    """Loads a specific dataset element from a group.

    Args:
        file (h5py.File): The HDF5 file object.
        path (str): The full path to the dataset element.
            e.g., "non_standard_elements/lens/lens_profile"

    Returns:
        DatasetElement: The loaded dataset element.
    """
    node = file[path]
    # The last path component is the dataset name; everything between the
    # root group and the leaf becomes the (possibly nested) group_name.
    *parents, leaf_name = path.split("/")
    return DatasetElement(
        dataset_name=leaf_name,
        data=node[()],
        description=node.attrs.get("description", ""),
        unit=node.attrs.get("unit", ""),
        group_name="/".join(parents[1:]),
    )