"""Base model class for all zea Keras models.
This module provides the `BaseModel` class for all zea Keras models.
"""
import importlib
import keras
from keras.src.saving.serialization_lib import record_object_after_deserialization
from zea import log
from zea.internal.core import classproperty
from zea.models.preset_utils import builtin_presets, get_preset_loader, get_preset_saver
[docs]
class BaseModel(keras.models.Model):
"""Base class for all zea Keras models.
A ``BaseModel`` is the basic model for zea.
"""
[docs]
@classmethod
def from_config(cls, config):
"""Create a model instance from a configuration dictionary.
The default ``from_config()`` for functional models will return a
vanilla ``keras.Model``. This override ensures a subclass instance is returned.
Args:
config (dict): Configuration dictionary.
Returns:
BaseModel: An instance of the model subclass.
"""
return cls(**config)
@classproperty
def presets(cls):
"""List built-in presets for a ``BaseModel`` subclass.
Returns:
dict: Dictionary of available built-in presets.
"""
return builtin_presets(cls)
[docs]
@classmethod
def from_preset(
cls,
preset,
load_weights=True,
**kwargs,
):
"""Instantiate a model from a preset.
A preset is a directory of configs, weights, and other file assets used
to save and load a pre-trained model. The ``preset`` can be passed as one of:
1. a built-in preset identifier like ``'bert_base_en'``
2. a Kaggle Models handle like ``'kaggle://user/bert/keras/bert_base_en'``
3. a Hugging Face handle like ``'hf://user/bert_base_en'``
4. a path to a local preset directory like ``'./bert_base_en'``
This constructor can be called in one of two ways: either from the base
class like ``keras_hub.models.Backbone.from_preset()``, or from
a model class like ``keras_hub.models.GemmaBackbone.from_preset()``.
If calling from the base class, the subclass of the returning object
will be inferred from the config in the preset directory.
For any ``Backbone`` subclass, you can run ``cls.presets.keys()`` to list
all built-in presets available on the class.
Args:
preset (str): A built-in preset identifier, a Kaggle Models
handle, a Hugging Face handle, or a path to a local directory.
load_weights (bool): If ``True``, the weights will be loaded into the
model architecture. If ``False``, the weights will be randomly
initialized.
**kwargs: Additional keyword arguments.
Example:
.. code-block:: python
# Load a Gemma backbone with pre-trained weights.
model = keras_hub.models.Backbone.from_preset(
"gemma_2b_en",
)
# Load a Bert backbone with a pre-trained config and random weights.
model = keras_hub.models.Backbone.from_preset(
"bert_base_en",
load_weights=False,
)
Returns:
BaseModel: The loaded model instance.
"""
loader = get_preset_loader(preset)
loader_cls = loader.check_model_class()
if cls != loader_cls:
full_cls_name = f"{cls.__module__}.{cls.__name__}"
full_loader_cls_name = f"{loader_cls.__module__}.{loader_cls.__name__}"
if issubclass(cls, loader_cls):
log.warning(
f"The preset '{preset}' is for model class '{full_loader_cls_name}', but you "
f"are calling from a subclass '{full_cls_name}', so the returned object will "
f"be of type '{full_cls_name}'."
)
elif issubclass(loader_cls, cls):
log.warning(
f"The preset '{preset}' is for model class '{full_loader_cls_name}', "
f"which is a subclass of the calling class '{full_cls_name}', "
f"so the returned object will be of type '{full_cls_name}'."
)
else:
raise ValueError(
f"The preset '{preset}' is for model class '{full_loader_cls_name}', "
f"which is not compatible with the calling class '{full_cls_name}'. "
f"Please call '{full_loader_cls_name}.from_preset()' instead."
)
return loader.load_model(cls, load_weights, **kwargs)
[docs]
def save_to_preset(self, preset_dir):
"""Save backbone to a preset directory.
Args:
preset_dir: The path to the local model preset directory.
"""
saver = get_preset_saver(preset_dir)
saver.save_model(self)
[docs]
def deserialize_zea_object(config, cls=None):
"""Retrieve the object by deserializing the config dict.
Need to borrow this function from keras and customize a bit to allow
deserialization of custom (zea) objects. See the original function here:
`keras.utils.deserialize_keras_object()`. As from the following keras
PR did not work on none Keras objects anymore:
- https://github.com/keras-team/keras/pull/20751
Args:
config (dict): The configuration dictionary
Returns:
obj (Object): The deserialized object
"""
class_name = config["class_name"]
inner_config = config["config"] or {}
if cls is None:
module = config.get("module", None)
registered_name = config.get("registered_name", class_name)
cls = _retrieve_class(module, registered_name, config)
if not hasattr(cls, "from_config"):
raise TypeError(
f"Unable to reconstruct an instance of '{class_name}' because "
f"the class is missing a `from_config()` method. "
f"Full object config: {config}"
)
try:
instance = cls.from_config(inner_config)
except TypeError as e:
raise TypeError(
f"{cls} could not be deserialized properly. Please"
" ensure that components that are Python object"
" instances (layers, models, etc.) returned by"
" `get_config()` are explicitly deserialized in the"
" model's `from_config()` method."
f"\n\nconfig={config}.\n\nException encountered: {e}"
) from e
build_config = config.get("build_config", None)
if build_config and not instance.built:
instance.build_from_config(build_config)
instance.built = True
compile_config = config.get("compile_config", None)
if compile_config:
instance.compile_from_config(compile_config)
instance.compiled = True
if "shared_object_id" in config:
record_object_after_deserialization(instance, config["shared_object_id"])
return instance
def _retrieve_class(module, class_name, config):
# Attempt to retrieve the class object given the `module`
# and `class_name`. Import the module, find the class.
assert module is not None, f"Module must be specified in the config.json, but got {module}."
package = module.split(".", maxsplit=1)[0]
if package in {"zea", "zea-addons"}:
try:
mod = importlib.import_module(module)
obj = vars(mod).get(class_name, None)
if obj is not None:
return obj
except ModuleNotFoundError as exc:
raise TypeError(
f"Could not deserialize class '{class_name}' because "
f"its parent module {module} cannot be imported. "
f"Full object config: {config}"
) from exc
raise TypeError(
f"Could not locate class '{class_name}'. "
"Make sure custom classes are decorated with "
"`@zea.registry.model_registry().`"
f"Full object config: {config}"
)