# Source code for zea.models.generative
"""Generative models for zea."""
import abc
from zea.models.base import BaseModel
[docs]
class GenerativeModel(abc.ABC):
    """Abstract interface shared by generative models.

    None of the methods is decorated as abstract, so the class itself can be
    instantiated; every method simply raises ``NotImplementedError`` until a
    subclass overrides it.
    """

    def fit(self, data, **kwargs):
        """Train the model on the given data.

        Args:
            data: The data the model should be fitted to.
            **kwargs: Extra options forwarded to the fitting procedure.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this method.
        """
        message = "fit() must be implemented in subclasses."
        raise NotImplementedError(message)

    def sample(self, n_samples=1, **kwargs):
        r"""Generate samples $x \sim p(x)$ from the model distribution.

        Args:
            n_samples: How many samples to generate.
            **kwargs: Extra options forwarded to the sampling procedure.

        Returns:
            Samples $x$ drawn from the model distribution $p(x)$.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this method.
        """
        message = "sample() must be implemented in subclasses."
        raise NotImplementedError(message)

    def posterior_sample(self, measurements, n_samples=1, **kwargs):
        r"""Generate posterior samples $z \sim p(z \mid x)$ given measurements.

        Args:
            measurements: The measurements $x$ the posterior is conditioned on.
            n_samples: How many posterior samples to generate. The samples
                occupy an extra output dimension: for `measurements` of shape
                `(batch_size, ...)` the output has shape
                `(batch_size, n_samples, ...)`.
            **kwargs: Extra options forwarded to the sampling procedure.

        Returns:
            Samples $z$ drawn from the posterior $p(z \mid x)$.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this method.
        """
        message = "posterior_sample() must be implemented in subclasses."
        raise NotImplementedError(message)

    def log_density(self, data, **kwargs):
        r"""Evaluate the log-density $\log p(x)$ the model assigns to the data.

        Args:
            data: The data $x$ whose log-density is requested.
            **kwargs: Additional arguments.

        Returns:
            The log-density $\log p(x)$ of the data.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this method.
        """
        message = "log_density() must be implemented in subclasses."
        raise NotImplementedError(message)
[docs]
class DeepGenerativeModel(BaseModel, GenerativeModel):
    """Base class for deep generative models.

    Derives from ``BaseModel`` and ``GenerativeModel`` so that the generative
    interface is combined with the Keras model functionality of ``BaseModel``.
    """

    def __init__(self, name="deep_generative_model", **kwargs):
        """Set up the model by running the ``BaseModel`` initializer.

        Args:
            name: Name given to the model.
            **kwargs: Extra arguments handed on to ``BaseModel``.
        """
        # `name` cannot also appear in `kwargs` (it is a named parameter), so
        # this mapping carries the same keywords, in the same order, as
        # passing `name=name, **kwargs` directly.
        base_kwargs = {"name": name, **kwargs}
        BaseModel.__init__(self, **base_kwargs)