# Source code for zea.backend.tensorflow.utils.utils
"""Utility functions for zea tensorflow modules."""
import numpy as np
import tensorflow as tf
[docs]
class DotDict(dict):
    """Dictionary whose items can also be accessed with dot notation.

    Attribute operations are forwarded to the underlying mapping:

    - reading ``d.key`` behaves like ``d.get("key")``, so a missing key
      yields ``None`` instead of raising ``AttributeError``;
    - assigning ``d.key = value`` behaves like ``d["key"] = value``;
    - ``del d.key`` behaves like ``del d["key"]`` and raises ``KeyError``
      when the key is absent.
    """

    def __getattr__(self, name, default=None):
        # Only invoked when regular attribute lookup fails, so real dict
        # methods (``keys``, ``items``, ...) are never shadowed by items.
        return dict.get(self, name, default)

    def __setattr__(self, name, value):
        # Attributes are stored as dictionary items, not in ``__dict__``.
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)
[docs]
def tf_snapshot(obj) -> dict:
    """Return a snapshot of the object parameters as a dictionary of tensors.

    Every name reported by ``dir(obj)`` that does not start with an
    underscore (and is not ``"angles"``) is read with ``getattr``. Values
    that are NumPy arrays, ints, floats or lists are converted with
    ``tf.convert_to_tensor`` and stored under the same name; values of any
    other type are skipped.

    Double-precision NumPy data (arrays *and* ``np.float64`` scalars, which
    subclass Python ``float``) is converted to ``tf.float32``. All other
    values use the dtype TensorFlow infers.

    Args:
        obj: Object whose public attributes should be captured. Note that
            ``getattr`` is called on every public name, so properties on
            ``obj`` are evaluated.

    Returns:
        dict: A ``DotDict`` mapping attribute names to tensors.
    """
    # Names starting with "_" are already skipped below, so only the public
    # name has to be listed here.
    excluded = ("angles",)
    convertible_types = (np.ndarray, int, float, list)

    snapshot = DotDict()
    for key in dir(obj):
        if key.startswith("_") or key in excluded:
            continue

        value = getattr(obj, key)
        if not isinstance(value, convertible_types):
            continue

        # If data is of double precision, convert to float32. Checking the
        # ``dtype`` attribute covers both ndarrays and NumPy float64 scalars;
        # plain Python floats/ints/lists have no ``dtype`` and keep
        # TensorFlow's default inference.
        is_double = getattr(value, "dtype", None) == np.float64
        dtype = tf.float32 if is_double else None
        snapshot[key] = tf.convert_to_tensor(value, dtype=dtype)
    return snapshot