# Source code for zea.data.data_format

"""
Functions to write and validate datasets in the zea format.
"""

import inspect
from dataclasses import dataclass
from pathlib import Path

import h5py
import numpy as np
from keras.utils import pad_sequences

from zea import log
from zea.data.file import File, validate_file
from zea.internal.checks import _DATA_TYPES
from zea.internal.utils import first_not_none_item


@dataclass
class DatasetElement:
    """A single user-supplied dataset to be stored in a zea file.

    Instances are passed to ``generate_zea_dataset`` via ``additional_elements``.
    Each one is written below the ``non_standard_elements`` group, and
    ``load_additional_elements`` returns the same kind of object when reading back.
    """

    # Key under which the data is stored inside its group.
    dataset_name: str
    # Array written as the dataset's content.
    data: np.ndarray
    # Saved as the dataset's "description" attribute.
    description: str
    # Saved as the dataset's "unit" attribute.
    unit: str
    # Optional sub-group (relative to "non_standard_elements"). May be nested,
    # e.g. "lens/profiles"; the empty string stores the dataset directly in the
    # "non_standard_elements" group.
    group_name: str = ""
def generate_example_dataset(
    path,
    add_optional_fields=False,
    add_optional_dtypes=False,
    n_frames=2,
    n_ax=2048,
    n_el=128,
    n_tx=11,
    n_ch=1,
    sound_speed=1540,
    center_frequency=7e6,
    sampling_frequency=40e6,
    grid_size_z=512,
    grid_size_x=512,
):
    """Generates an example dataset that contains all the necessary fields.

    Note:
        This dataset does not contain actual data. The data arrays are filled with
        constant placeholder values (ones; the image is a constant -40 dB). A few
        example non-standard elements are added as well (one of them random).

    Args:
        path (str): The path to write the dataset to.
        add_optional_fields (bool, optional): Whether to add the optional scan fields
            focus_distances, tx_apodizations, polar_angles and azimuth_angles.
            Defaults to False.
        add_optional_dtypes (bool, optional): Whether to add the optional data types
            aligned_data, envelope_data, beamformed_data and image_sc.
            Defaults to False.
        n_frames (int, optional): Number of frames. Defaults to 2.
        n_ax (int, optional): Number of axial samples of the raw data. Defaults to 2048.
        n_el (int, optional): Number of probe elements. Defaults to 128.
        n_tx (int, optional): Number of transmits per frame. Defaults to 11.
        n_ch (int, optional): Number of channels of the raw data. Defaults to 1.
        sound_speed (float, optional): Speed of sound in m/s. Defaults to 1540.
        center_frequency (float, optional): Center frequency in Hz. Defaults to 7e6.
        sampling_frequency (float, optional): Sampling frequency in Hz.
            Defaults to 40e6.
        grid_size_z (int, optional): Number of image rows. Defaults to 512.
        grid_size_x (int, optional): Number of image columns. Defaults to 512.
    """
    # creating some fake raw and image data
    raw_data = np.ones((n_frames, n_tx, n_ax, n_el, n_ch))
    # image data is in dB
    image = np.ones((n_frames, grid_size_z, grid_size_x)) * -40

    # creating some fake scan parameters
    t0_delays = np.zeros((n_tx, n_el), dtype=np.float32)
    probe_geometry = np.zeros((n_el, 3), dtype=np.float32)
    # linear array spanning 4 cm along x
    probe_geometry[:, 0] = np.linspace(-0.02, 0.02, n_el)
    initial_times = np.zeros((n_tx,), dtype=np.float32)

    if add_optional_fields:
        # infinite focus distance corresponds to plane waves
        focus_distances = np.ones((n_tx,), dtype=np.float32) * np.inf
        # NOTE(review): all-zero apodization is a placeholder value only
        tx_apodizations = np.zeros((n_tx, n_el), dtype=np.float32)
        polar_angles = np.zeros((n_tx,), dtype=np.float32)
        azimuth_angles = np.zeros((n_tx,), dtype=np.float32)
    else:
        focus_distances = None
        tx_apodizations = None
        polar_angles = None
        azimuth_angles = None

    if add_optional_dtypes:
        aligned_data = np.ones((n_frames, n_tx, n_ax, n_el, n_ch))
        envelope_data = np.ones((n_frames, grid_size_z, grid_size_x))
        beamformed_data = np.ones((n_frames, grid_size_z, grid_size_x, n_ch))
        image_sc = np.ones_like(image)
    else:
        aligned_data = None
        envelope_data = None
        beamformed_data = None
        image_sc = None

    generate_zea_dataset(
        path,
        raw_data=raw_data,
        aligned_data=aligned_data,
        envelope_data=envelope_data,
        beamformed_data=beamformed_data,
        image=image,
        image_sc=image_sc,
        probe_geometry=probe_geometry,
        sampling_frequency=sampling_frequency,
        center_frequency=center_frequency,
        initial_times=initial_times,
        t0_delays=t0_delays,
        sound_speed=sound_speed,
        tx_apodizations=tx_apodizations,
        probe_name="generic",
        focus_distances=focus_distances,
        polar_angles=polar_angles,
        azimuth_angles=azimuth_angles,
        additional_elements=_generate_example_dataset_elements(),
        description="This is an example dataset generated by zea",
    )
def _generate_example_dataset_elements() -> list[DatasetElement]:
    """Build a small set of illustrative non-standard dataset elements.

    The elements are a 0-d integer array, a random 1D float array stored in the
    nested group "lens", and an array of variable-length strings in that same
    group. They demonstrate the ``additional_elements`` argument of
    generate_zea_dataset.

    Returns:
        list: Three DatasetElement objects.
    """
    temperature = DatasetElement(
        dataset_name="temperature",
        data=np.array(42),
        description="The temperature during the measurement",
        unit="unitless",
    )
    lens_profile = DatasetElement(
        dataset_name="lens_profile",
        data=np.random.rand(100),
        description="An example lens profile",
        unit="mm",
        group_name="lens",
    )
    materials = np.array(["material1", "material2", "material3"], dtype=h5py.string_dtype())
    lens_material = DatasetElement(
        dataset_name="lens_material",
        data=materials,
        description="An example lens material list",
        unit="unitless",
        group_name="lens",
    )
    return [temperature, lens_profile, lens_material]
def validate_input_data(raw_data, aligned_data, envelope_data, beamformed_data, image, image_sc):
    """Ensure that generate_zea_dataset received at least one data array.

    Only the presence of data is checked here; specific checks for each data
    type are done in validate_file.

    Args:
        raw_data (np.ndarray): Raw data of shape (n_frames, n_tx, n_ax, n_el, n_ch).
        aligned_data (np.ndarray): Aligned data of shape
            (n_frames, n_tx, n_ax, n_el, n_ch).
        envelope_data (np.ndarray): Envelope data of shape
            (n_frames, grid_size_z, grid_size_x).
        beamformed_data (np.ndarray): Beamformed data of shape
            (n_frames, grid_size_z, grid_size_x, n_ch).
        image (np.ndarray): Images of shape (n_frames, grid_size_z, grid_size_x).
        image_sc (np.ndarray): Scan-converted images of shape
            (n_frames, output_size_z, output_size_x).

    Raises:
        AssertionError: If every one of the data arguments is None.
    """
    candidates = (raw_data, aligned_data, envelope_data, beamformed_data, image, image_sc)
    assert any(candidate is not None for candidate in candidates), (
        f"At least one of the data types {_DATA_TYPES} must be specified."
    )
# specific checks for each data type are done in validate_file
def _write_datasets(
    dataset,
    data_group_name="data",
    scan_group_name="scan",
    raw_data=None,
    aligned_data=None,
    envelope_data=None,
    beamformed_data=None,
    image=None,
    image_sc=None,
    n_ax=None,
    n_el=None,
    n_tx=None,
    n_ch=None,
    n_frames=None,
    sound_speed=None,
    probe_geometry=None,
    sampling_frequency=None,
    center_frequency=None,
    demodulation_frequency=None,
    initial_times=None,
    t0_delays=None,
    tx_apodizations=None,
    focus_distances=None,
    transmit_origins=None,
    polar_angles=None,
    azimuth_angles=None,
    bandwidth_percent=None,
    time_to_next_transmit=None,
    tgc_gain_curve=None,
    element_width=None,
    tx_waveform_indices=None,
    waveforms_one_way=None,
    waveforms_two_way=None,
    additional_elements=None,
    cast_to_float=True,
    enable_compression=True,
    **kwargs,
):
    """Write a data group, a scan group and optional extra elements into ``dataset``.

    This is the worker behind generate_zea_dataset. It is called once for a flat
    file, or once per event (with ``event_i/data`` / ``event_i/scan`` group names)
    when the event structure is used. Arguments that are None are not written.

    Args:
        dataset: Open HDF5 file/group to write into.
        data_group_name (str): Name of the group receiving the data arrays.
        scan_group_name (str): Name of the group receiving the scan parameters.
        raw_data, aligned_data, envelope_data, beamformed_data, image, image_sc:
            Data arrays; see generate_zea_dataset for their shapes.
        n_ax, n_el, n_tx, n_ch, n_frames: Dimension sizes. When None they are
            inferred from the supplied arrays; ``n_tx`` falls back to 1.
        sound_speed ... waveforms_two_way: Scan parameters; see
            generate_zea_dataset.
        additional_elements (list[DatasetElement]): Extra datasets written below
            the root-level ``non_standard_elements`` group.
        cast_to_float (bool): Cast the data arrays to float32 before writing.
        enable_compression (bool): Use gzip compression for non-scalar datasets.
        **kwargs: Must be empty; catches misspelled/unknown arguments.

    Raises:
        ValueError: If unknown keyword arguments are passed.
    """
    if kwargs:
        raise ValueError(f"Unknown arguments: {list(kwargs.keys())}")

    def _convert_datatype(x, astype=np.float32):
        """Cast ``x`` to ``astype`` when cast_to_float is set; None passes through."""
        if x is None or not cast_to_float:
            return x
        # np.asarray also accepts plain (nested) lists, not only ndarrays
        return np.asarray(x).astype(astype)

    def _first_not_none_shape(arr, axis):
        """Size along ``axis`` of the first non-None entry of ``arr``, else None."""
        data = first_not_none_item(arr)
        return np.shape(data)[axis] if data is not None else None

    def _add_dataset(group_name: str, name: str, data: np.ndarray, description: str, unit: str):
        """Adds a dataset to the given group with a description and unit.

        If data is None, the dataset is not added. The (possibly nested) group is
        created on demand."""
        if data is None:
            return
        data = np.asarray(data)
        group = dataset.require_group(group_name)
        # gzip is not applicable to scalar (0-d) datasets
        compression = "gzip" if enable_compression and data.ndim != 0 else None
        new_dataset = group.create_dataset(name, data=data, compression=compression)
        new_dataset.attrs["description"] = description
        new_dataset.attrs["unit"] = unit

    # Write data group
    data_group = dataset.create_group(data_group_name)
    data_group.attrs["description"] = "This group contains the data."

    # Infer missing dimension sizes, preferring the data arrays over scan parameters.
    if n_frames is None:
        n_frames = first_not_none_item(
            [raw_data, aligned_data, envelope_data, beamformed_data, image, image_sc]
        ).shape[0]
    if n_tx is None:
        n_tx = _first_not_none_shape([raw_data, aligned_data], axis=1)
    if n_ax is None:
        # NOTE(review): for beamformed_data axis -3 is grid_size_z, not the raw n_ax
        n_ax = _first_not_none_shape([raw_data, aligned_data, beamformed_data], axis=-3)
    if n_el is None:
        n_el = _first_not_none_shape([raw_data], axis=-2)
    if n_ch is None:
        n_ch = _first_not_none_shape([raw_data, aligned_data, beamformed_data], axis=-1)
    if n_tx is None:
        n_tx = _first_not_none_shape(
            [t0_delays, focus_distances, polar_angles, transmit_origins], axis=0
        )
    if n_el is None:
        n_el = _first_not_none_shape([t0_delays], axis=1)
    if n_el is None:
        n_el = _first_not_none_shape([probe_geometry], axis=0)
    if n_tx is None:
        n_tx = 1

    data_fields = [
        (
            "raw_data",
            raw_data,
            "The raw_data of shape (n_frames, n_tx, n_ax, n_el, n_ch).",
        ),
        (
            "aligned_data",
            aligned_data,
            "The aligned_data of shape (n_frames, n_tx, n_ax, n_el, n_ch).",
        ),
        (
            "envelope_data",
            envelope_data,
            "The envelope_data of shape (n_frames, grid_size_z, grid_size_x).",
        ),
        (
            "beamformed_data",
            beamformed_data,
            "The beamformed_data of shape (n_frames, grid_size_z, grid_size_x, n_ch).",
        ),
        (
            "image",
            image,
            "The images of shape [n_frames, grid_size_z, grid_size_x]",
        ),
        (
            "image_sc",
            image_sc,
            "The scan converted images of shape [n_frames, output_size_z, output_size_x]",
        ),
    ]
    for name, value, description in data_fields:
        _add_dataset(
            group_name=data_group_name,
            name=name,
            data=_convert_datatype(value),
            description=description,
            unit="unitless",
        )

    # Write scan group
    scan_group = dataset.create_group(scan_group_name)
    scan_group.attrs["description"] = "This group contains the scan parameters."

    scan_fields = [
        ("n_ax", n_ax, "The number of axial samples.", "unitless"),
        ("n_el", n_el, "The number of elements in the probe.", "unitless"),
        ("n_tx", n_tx, "The number of transmits per frame.", "unitless"),
        (
            "n_ch",
            n_ch,
            "The number of channels. For RF data this is 1. For IQ data this is 2.",
            "unitless",
        ),
        ("n_frames", n_frames, "The number of frames.", "unitless"),
        ("sound_speed", sound_speed, "The speed of sound in m/s", "m/s"),
        ("probe_geometry", probe_geometry, "The probe geometry of shape (n_el, 3).", "m"),
        ("sampling_frequency", sampling_frequency, "The sampling frequency in Hz.", "Hz"),
        (
            "center_frequency",
            center_frequency,
            "The center frequency of the transmit pulse in Hz.",
            "Hz",
        ),
        (
            "demodulation_frequency",
            demodulation_frequency,
            "The frequency at which the data should be "
            "demodulated in Hz. (Usually the same as center_frequency, "
            "but different when doing harmonic imaging.)",
            "Hz",
        ),
        (
            "initial_times",
            initial_times,
            "The times when the A/D converter starts sampling "
            "in seconds of shape (n_tx,). This is the time between the "
            "first element firing and the first recorded sample.",
            "s",
        ),
        ("t0_delays", t0_delays, "The t0_delays of shape (n_tx, n_el).", "s"),
        (
            "tx_apodizations",
            tx_apodizations,
            # Previously this described transmit delays in seconds, which
            # contradicted both the field and its "unitless" unit.
            "The transmit apodization weights for each element of shape (n_tx, n_el).",
            "unitless",
        ),
        (
            "focus_distances",
            focus_distances,
            "The transmit focus distances in meters of "
            "shape (n_tx,). This is the distance from the origin point on the transducer to where "
            "the beam comes to focus. For planewaves this is set to Inf.",
            "m",
        ),
        (
            "transmit_origins",
            transmit_origins,
            "The transmit origins in meters of the transmit beams "
            "of shape (n_tx, 3). This is the (x, y, z) position "
            "from which the beam is transmitted.",
            "m",
        ),
        (
            "polar_angles",
            polar_angles,
            "The polar angles of the transmit beams in radians of shape (n_tx,).",
            "rad",
        ),
        (
            "azimuth_angles",
            azimuth_angles,
            "The azimuthal angles of the transmit beams in radians of shape (n_tx,).",
            "rad",
        ),
        (
            "bandwidth_percent",
            bandwidth_percent,
            "The receive bandwidth of RF signal in percentage of center frequency.",
            "unitless",
        ),
        (
            "time_to_next_transmit",
            time_to_next_transmit,
            "The time between subsequent transmit events of shape (n_frames, n_tx).",
            "s",
        ),
        (
            "tgc_gain_curve",
            tgc_gain_curve,
            "The time-gain-compensation that was applied to every sample in the "
            "raw_data of shape (n_ax,). Divide by this curve to undo the TGC.",
            "unitless",
        ),
        (
            "element_width",
            element_width,
            "The width of the elements in the probe in meters.",
            "m",
        ),
    ]
    for name, value, description, unit in scan_fields:
        _add_dataset(
            group_name=scan_group_name,
            name=name,
            data=value,
            description=description,
            unit=unit,
        )

    # Waveform indices are only meaningful when there are waveforms to index.
    if tx_waveform_indices is not None and (
        waveforms_one_way is not None or waveforms_two_way is not None
    ):
        _add_dataset(
            group_name=scan_group_name,
            name="tx_waveform_indices",
            data=tx_waveform_indices,
            description=(
                "Transmit indices for waveforms, indexing waveforms_one_way "
                "and waveforms_two_way. This indicates which transmit waveform was "
                "used for each transmit event."
            ),
            unit="-",
        )

    # Waveforms may differ in length; pad_sequences zero-pads them at the end
    # into one rectangular float32 array.
    if waveforms_one_way is not None:
        _add_dataset(
            group_name=scan_group_name,
            name="waveforms_one_way",
            data=pad_sequences(waveforms_one_way, dtype=np.float32, padding="post"),
            description=(
                "One-way waveform as simulated by the Verasonics system, "
                "sampled at 250MHz. This is the waveform after being filtered "
                "by the transducer bandwidth once."
            ),
            unit="V",
        )

    if waveforms_two_way is not None:
        _add_dataset(
            group_name=scan_group_name,
            name="waveforms_two_way",
            data=pad_sequences(waveforms_two_way, dtype=np.float32, padding="post"),
            description=(
                "Two-way waveform as simulated by the Verasonics system, "
                "sampled at 250MHz. This is the waveform after being filtered "
                "by the transducer bandwidth twice."
            ),
            unit="V",
        )

    # Write non-standard (user supplied) elements
    if additional_elements is not None:
        # NOTE(review): this group lives at the file root, so writing additional
        # elements for more than one event (event structure) collides on the
        # group/dataset names. Confirm the intended per-event layout.
        non_standard_elements_group_name = "non_standard_elements"
        non_standard_elements_group = dataset.create_group(non_standard_elements_group_name)
        non_standard_elements_group.attrs["description"] = (
            "This group contains non-standard elements that can be added by the user."
        )

        for element in additional_elements:
            group_name = non_standard_elements_group_name
            if element.group_name != "":
                group_name += f"/{element.group_name}"
            _add_dataset(
                group_name=group_name,
                name=element.dataset_name,
                data=element.data,
                description=element.description,
                unit=element.unit,
            )
def generate_zea_dataset(
    path,
    raw_data=None,
    aligned_data=None,
    envelope_data=None,
    beamformed_data=None,
    image=None,
    image_sc=None,
    probe_geometry=None,
    sampling_frequency=None,
    center_frequency=None,
    demodulation_frequency=None,
    initial_times=None,
    t0_delays=None,
    sound_speed=None,
    probe_name=None,
    description="No description was supplied",
    focus_distances=None,
    transmit_origins=None,
    polar_angles=None,
    azimuth_angles=None,
    tx_apodizations=None,
    bandwidth_percent=None,
    time_to_next_transmit=None,
    tgc_gain_curve=None,
    element_width=None,
    tx_waveform_indices=None,
    waveforms_one_way=None,
    waveforms_two_way=None,
    additional_elements=None,
    event_structure=False,
    cast_to_float=True,
    overwrite=False,
    enable_compression=True,
):
    """Generates a dataset in the zea format.

    Args:
        path (str): The path to write the dataset to.
        raw_data (np.ndarray): The raw data of the ultrasound measurement of
            shape (n_frames, n_tx, n_ax, n_el, n_ch).
        aligned_data (np.ndarray): The aligned data of the ultrasound measurement of
            shape (n_frames, n_tx, n_ax, n_el, n_ch).
        envelope_data (np.ndarray): The envelope data of the ultrasound measurement of
            shape (n_frames, grid_size_z, grid_size_x).
        beamformed_data (np.ndarray): The beamformed data of the ultrasound measurement of
            shape (n_frames, grid_size_z, grid_size_x, n_ch).
        image (np.ndarray): The ultrasound images to be saved
            of shape (n_frames, grid_size_z, grid_size_x).
        image_sc (np.ndarray): The scan converted ultrasound images to be saved
            of shape (n_frames, output_size_z, output_size_x).
        probe_geometry (np.ndarray): The probe geometry of shape (n_el, 3).
        sampling_frequency (float): The sampling frequency in Hz.
        center_frequency (float): The center frequency of the transmit pulse in Hz.
        demodulation_frequency (float): The demodulation frequency in Hz.
        initial_times (list): The times when the A/D converter starts sampling
            in seconds of shape (n_tx,). This is the time between the first element
            firing and the first recorded sample.
        t0_delays (np.ndarray): The t0_delays of shape (n_tx, n_el).
        sound_speed (float): The speed of sound in m/s.
        probe_name (str): The name of the probe. Required.
        description (str): The description of the dataset.
        focus_distances (np.ndarray): The focus distances of shape (n_tx,).
        transmit_origins (np.ndarray): The transmit origins of shape (n_tx, 3).
        polar_angles (np.ndarray): The polar angles (radians) of shape (n_tx,).
        azimuth_angles (np.ndarray): The azimuth angles (radians) of shape (n_tx,).
        tx_apodizations (np.ndarray): The transmit apodization weights for each
            element of shape (n_tx, n_el).
        bandwidth_percent (float): The bandwidth of the transducer as a
            percentage of the center frequency.
        time_to_next_transmit (np.ndarray): The time between subsequent transmit
            events in s of shape (n_frames, n_tx).
        tgc_gain_curve (np.ndarray): The TGC gain that was applied to every sample
            in the raw_data of shape (n_ax,).
        element_width (float): The width of the elements in the probe in meters.
        tx_waveform_indices (np.ndarray): Transmit indices for waveforms, indexing
            waveforms_one_way and waveforms_two_way. This indicates which transmit
            waveform was used for each transmit event. Only written when at least
            one of the waveform lists is given.
        waveforms_one_way (list): List of one-way waveforms as simulated by the
            Verasonics system, sampled at 250MHz. This is the waveform after being
            filtered by the transducer bandwidth once. Every element in the list is
            a 1D numpy array.
        waveforms_two_way (list): List of two-way waveforms as simulated by the
            Verasonics system, sampled at 250MHz. This is the waveform after being
            filtered by the transducer bandwidth twice. Every element in the list
            is a 1D numpy array.
        additional_elements (List[DatasetElement]): A list of additional dataset
            elements to be added to the dataset. Each element should be a
            DatasetElement object. The additional elements are added under the
            ``non_standard_elements`` group.
        event_structure (bool): Whether to write the dataset with an event structure.
            In that case all given arguments (including probe_name and description)
            must be lists with the same length (number of events). The data will be
            stored under event_{i}/data and event_{i}/scan for each event i,
            instead of just a single data and scan group.
        cast_to_float (bool): Whether to store data as float32. You may want to set this
            to False if storing images.
        overwrite (bool): Whether to overwrite the file if it already exists.
            Defaults to False.
        enable_compression (bool): Whether to enable gzip compression for datasets.
            Defaults to True. Compression reduces disk space at the cost of
            increased write time.

    Raises:
        FileExistsError: If ``path`` exists and ``overwrite`` is False.
        AssertionError: If the arguments are inconsistent (e.g. no data given,
            differing probe names or event counts, non-string probe name).
    """
    # check if all args are lists
    if isinstance(probe_name, list):
        # all names in probe_name list should be the same
        assert len(set(probe_name)) == 1, "Probe names for all events should be the same"

    data_and_parameters = {
        "raw_data": raw_data,
        "aligned_data": aligned_data,
        "envelope_data": envelope_data,
        "beamformed_data": beamformed_data,
        "image": image,
        "image_sc": image_sc,
        "probe_geometry": probe_geometry,
        "sampling_frequency": sampling_frequency,
        "center_frequency": center_frequency,
        "demodulation_frequency": demodulation_frequency,
        "initial_times": initial_times,
        "t0_delays": t0_delays,
        "sound_speed": sound_speed,
        "probe_name": probe_name,
        "description": description,
        "focus_distances": focus_distances,
        "transmit_origins": transmit_origins,
        "polar_angles": polar_angles,
        "azimuth_angles": azimuth_angles,
        "tx_apodizations": tx_apodizations,
        "bandwidth_percent": bandwidth_percent,
        "time_to_next_transmit": time_to_next_transmit,
        "tgc_gain_curve": tgc_gain_curve,
        "element_width": element_width,
        "tx_waveform_indices": tx_waveform_indices,
        "waveforms_one_way": waveforms_one_way,
        "waveforms_two_way": waveforms_two_way,
        "additional_elements": additional_elements,
    }

    # make sure input arguments of func is same length as data_and_parameters
    # except `path`, `event_structure`, `cast_to_float`, `overwrite` and
    # `enable_compression`. This guards against forgetting a new argument here.
    assert (
        len(data_and_parameters) == len(inspect.signature(generate_zea_dataset).parameters) - 5
    ), (
        "All arguments should be put in data_and_parameters except "
        "`path`, `event_structure`, `cast_to_float`, `overwrite`, and `enable_compression` "
        "arguments."
    )

    if event_structure:
        # The reference event count must persist across iterations; resetting it
        # inside the loop (as before) silently disabled the length check.
        _num_events = None
        for argument, argument_value in data_and_parameters.items():
            if argument_value is None:
                continue
            assert isinstance(argument_value, list), (
                f"{argument} should be a list when event_structure is set to True."
            )
            num_events = len(argument_value)
            if _num_events is not None:
                assert num_events == _num_events, (
                    "All arguments should have the same number of events."
                )
            _num_events = num_events

        assert len(set(probe_name)) == 1, "Probe names for all events should be the same"
        log.info(
            f"Event structure is set to True. Writing dataset with event "
            f"structure (found {len(probe_name)} events)."
        )
        num_events = len(probe_name)
        # probe name and description are stored once, as file attributes
        probe_name = probe_name[0]
        description = description[0]

    assert isinstance(probe_name, str), "The probe name must be a string."
    assert isinstance(description, str), "The description must be a string."
    assert isinstance(event_structure, bool), "The event_structure must be a boolean."

    validate_input_data(
        raw_data=raw_data,
        aligned_data=aligned_data,
        envelope_data=envelope_data,
        beamformed_data=beamformed_data,
        image=image,
        image_sc=image_sc,
    )

    # Convert path to Path object
    path = Path(path)

    if path.exists() and not overwrite:
        raise FileExistsError(f"The file {path} already exists.")

    # Create the directory if it does not exist
    path.parent.mkdir(parents=True, exist_ok=True)

    with File(path, "w") as dataset:
        dataset.attrs["probe"] = probe_name
        dataset.attrs["description"] = description
        dataset.attrs["event_structure"] = event_structure

        # remove probe and description from data_and_parameters; they are
        # file-level attributes, not per-group datasets
        data_and_parameters.pop("probe_name")
        data_and_parameters.pop("description")

        if event_structure:
            for i in range(num_events):
                _data_and_parameters = {
                    k: v[i] for k, v in data_and_parameters.items() if v is not None
                }
                _write_datasets(
                    dataset,
                    data_group_name=f"event_{i}/data",
                    scan_group_name=f"event_{i}/scan",
                    cast_to_float=cast_to_float,
                    enable_compression=enable_compression,
                    **_data_and_parameters,
                )
        else:
            _write_datasets(
                dataset,
                data_group_name="data",
                scan_group_name="scan",
                cast_to_float=cast_to_float,
                enable_compression=enable_compression,
                **data_and_parameters,
            )

    validate_file(path)
    log.info(f"zea dataset written to {log.yellow(path)}")
def load_description(path):
    """Read the file-level ``description`` attribute of a zea dataset.

    Args:
        path (str): Location of the zea dataset.

    Returns:
        str: The stored description, or an empty string when the attribute is absent.
    """
    with File(Path(path), "r") as handle:
        return handle.attrs.get("description", "")
def load_additional_elements(path):
    """Read every DatasetElement stored below ``non_standard_elements``.

    Args:
        path (str): Location of the zea dataset.

    Returns:
        list: DatasetElement objects (nested groups included); an empty list when
        the file has no ``non_standard_elements`` group.
    """
    root_group = "non_standard_elements"
    with File(Path(path), "r") as handle:
        if root_group in handle:
            return _load_additional_elements_from_group(handle, root_group)
    return []
def _load_additional_elements_from_group(file, path):
    """Collect DatasetElement objects from ``file[path]`` and all of its subgroups.

    Datasets are converted with _load_dataset_element_from_group; groups are
    descended into recursively. Members that are neither are skipped.
    """
    collected = []
    for member_name, member in file[path].items():
        member_path = f"{path}/{member_name}"
        if isinstance(member, h5py.Group):
            collected.extend(_load_additional_elements_from_group(file, member_path))
        elif isinstance(member, h5py.Dataset):
            collected.append(_load_dataset_element_from_group(file, member_path))
    return collected


def _load_dataset_element_from_group(file, path):
    """Turn the HDF5 dataset at ``path`` back into a DatasetElement.

    Args:
        file (h5py.File): The open HDF5 file.
        path (str): Full path of the dataset, e.g.
            "non_standard_elements/lens/lens_profile".

    Returns:
        DatasetElement: Element whose ``group_name`` is the path between the first
        component (the root group) and the dataset name; missing "description"
        or "unit" attributes become empty strings.
    """
    node = file[path]
    *parents, dataset_name = path.split("/")
    return DatasetElement(
        dataset_name=dataset_name,
        data=node[()],
        description=node.attrs.get("description", ""),
        unit=node.attrs.get("unit", ""),
        group_name="/".join(parents[1:]),
    )