Source code for zea.models.utils
"""Utilities for models"""
from collections.abc import Mapping

import keras
[docs]
class LossTrackerWrapper:
    """A wrapper for Keras Mean metrics to track multiple loss values.

    Trackers are created lazily: one per distinct name seen by
    :meth:`update_state`, kept in the order they were first created.

    Attributes:
        prefix (str): Prefix used to build every tracker name.
        trackers (dict): Maps tracker name to its ``keras.metrics.Mean``.
    """

    def __init__(self, prefix):
        """
        Initialize the loss tracker wrapper.

        Args:
            prefix (str): Prefix to use for the loss name. For example "n_loss" or "i_loss".
        """
        self.prefix = prefix
        self.trackers = {}

    def _get_tracker(self, name):
        """Return the tracker called ``name``, creating a ``keras.metrics.Mean`` on first use."""
        if name not in self.trackers:
            # NOTE(review): the metric (and its state) is created the first time a
            # name is seen; if update_state runs inside a compiled/traced step,
            # confirm the backend in use tolerates creating it there.
            self.trackers[name] = keras.metrics.Mean(name=name)
        return self.trackers[name]

    def update_state(self, loss_value):
        """
        Update the tracker(s) with a loss value.

        If loss_value is a mapping (e.g. a dict), then for each key a separate
        tracker is created (if not already created) and updated. The tracker's
        name will be <prefix>_<key>. Otherwise a default tracker with name
        <prefix> is updated.

        Args:
            loss_value: A tensor, or a mapping from field names to tensors.
        """
        # Accept any Mapping, not only dict. A non-dict mapping would otherwise
        # fall through to the single-tracker branch and be handed to
        # Mean.update_state as if it were a tensor.
        if isinstance(loss_value, Mapping):
            for key, value in loss_value.items():
                self._get_tracker(f"{self.prefix}_{key}").update_state(value)
        else:
            self._get_tracker(self.prefix).update_state(loss_value)

    def result(self):
        """
        Return the current average of every tracker.

        Returns:
            dict: Maps each tracker's name (e.g. "n_loss_a") to its
            ``result()``, in the order the trackers were created.
        """
        return {
            tracker.name: tracker.result() for tracker in self.trackers.values()
        }

    def reset_state(self):
        """Reset all the internal trackers."""
        for tracker in self.trackers.values():
            tracker.reset_state()

    def __iter__(self):
        """Iterate over the underlying tracker objects in creation order."""
        return iter(self.trackers.values())