# Source code for zea.data.utils
"""Utility functions for zea datasets."""
import json
from pathlib import Path
from keras import ops
class ZeaJSONEncoder(json.JSONEncoder):
    """JSON encoder that additionally handles ``range``, ``slice`` and ``Path``.

    ``range`` and ``slice`` objects are emitted as tagged dictionaries
    (a ``"__type__"`` key plus ``start``/``stop``/``step``) so that they can
    be rebuilt by ``json_loads`` in this module. ``pathlib.Path`` objects are
    emitted as plain strings, so they come back as ``str`` after decoding.

    Example:
        >>> import json
        >>> from zea.data.utils import ZeaJSONEncoder
        >>> json.dumps(range(10), cls=ZeaJSONEncoder)
        '{"__type__": "range", "start": 0, "stop": 10, "step": 1}'

    Note:
        Usually you would call the ``json_dumps()`` helper defined alongside
        this class rather than using the encoder directly.
    """

    def default(self, o):
        """Return a JSON-serializable stand-in for ``o``.

        Args:
            o: an object the base encoder cannot serialize natively.

        Returns:
            A tagged ``dict`` for ``range``/``slice`` objects, or ``str(o)``
            for ``Path`` objects.

        Raises:
            TypeError: for any other type (raised by ``JSONEncoder.default``).
        """
        # Neither range nor slice can be subclassed, and neither is a Path,
        # so the order of these checks does not matter.
        if isinstance(o, Path):
            return str(o)
        if isinstance(o, (range, slice)):
            tag = "range" if isinstance(o, range) else "slice"
            return {
                "__type__": tag,
                "start": o.start,
                "stop": o.stop,
                "step": o.step,
            }
        return super().default(o)
def json_dumps(obj):
    """Serialize ``obj`` to a JSON string with ``range``/``slice`` support.

    Args:
        obj: object to serialize (most likely a dictionary). ``range`` and
            ``slice`` values are written in a tagged form that ``json_loads``
            can turn back into the original objects; ``Path`` values are
            written as strings.

    Returns:
        str: the serialized JSON text.
    """
    # Equivalent to json.dumps(obj, cls=ZeaJSONEncoder): with no extra keyword
    # arguments, json.dumps builds the encoder with its default settings.
    encoder = ZeaJSONEncoder()
    return encoder.encode(obj)
def json_loads(obj):
    """Deserialize JSON text, rebuilding tagged ``range``/``slice`` objects.

    Args:
        obj: the JSON document to decode (most likely a ``str``; anything
            accepted by ``json.loads`` works).

    Returns:
        object: the decoded value (most likely a dictionary). Every object
        tagged with ``"__type__": "range"`` or ``"slice"`` is replaced by
        the corresponding Python object.
    """
    # The hook runs on every decoded JSON object, innermost first.
    hook = _zea_datasets_json_decoder
    return json.loads(obj, object_hook=hook)
def decode_file_info(file_info):
    """Decode file info from JSON strings.

    A batch of H5Generator can return a list of file_info entries that are
    JSON strings. This function decodes those strings and returns a list of
    dictionaries with the information, namely:

    - full_path: full path to the file
    - file_name: file name
    - indices: indices used to extract the image from the file

    Args:
        file_info: either a single encoded entry (a 0-d tensor/array, or a
            bare ``bytes``/``str``) or a 1-d batch of entries (tensor, array
            or Python list). Entries may be byte strings or text strings.

    Returns:
        list[dict]: one decoded dictionary per entry, with ``range`` and
        ``slice`` values restored by ``json_loads``.
    """
    # Treat a scalar entry as a batch of one. Python lists and bare strings
    # have no ``ndim`` attribute, so read it with getattr instead of accessing
    # it directly (which would raise AttributeError).
    is_bare_string = isinstance(file_info, (str, bytes))
    if is_bare_string or getattr(file_info, "ndim", None) == 0:
        file_info = [file_info]

    decoded_info = []
    for info in file_info:
        # ``[()]`` extracts the scalar element from the 0-d numpy array.
        text = ops.convert_to_numpy(info)[()]
        # Byte-string arrays yield ``np.bytes_`` (a ``bytes`` subclass), which
        # must be decoded. Unicode arrays yield ``np.str_`` (a ``str``
        # subclass), which json_loads accepts as-is.
        if isinstance(text, bytes):
            text = text.decode("utf-8")
        decoded_info.append(json_loads(text))
    return decoded_info
def _zea_datasets_json_decoder(dct):
    """``object_hook`` for ``json.loads`` that rebuilds range and slice objects.

    Args:
        dct: a freshly decoded JSON object.

    Returns:
        A ``range`` or ``slice`` when ``dct["__type__"]`` is ``"range"`` or
        ``"slice"`` (built from its ``start``/``stop``/``step`` entries);
        otherwise ``dct`` unchanged.
    """
    tag = dct.get("__type__")
    # Compare with == rather than looking the tag up in a dict: after earlier
    # hook calls the tag may be an unhashable value (list/dict), which must
    # simply fall through unchanged.
    for name, factory in (("range", range), ("slice", slice)):
        if tag == name:
            return factory(dct["start"], dct["stop"], dct["step"])
    return dct