Source code for zea.models.hvae

"""
Hierarchical Variational Auto-Encoder for image generation, posterior sampling and inference tasks.
To try this model, simply load one of the available presets:

.. doctest::

    >>> from zea.models.hvae import HierarchicalVAE

    >>> model = HierarchicalVAE.from_preset("hvae")  # doctest: +SKIP

.. important::
    This is a ``zea`` implementation of the model.
    For the original code, see `here <https://github.com/swpenninga/hvae>`_.

.. seealso::
    A tutorial notebook where this model is used:
    :doc:`../notebooks/models/hvae_model_example`.

"""

import pickle

from keras import ops

from zea.internal.registry import model_registry
from zea.models.generative import DeepGenerativeModel
from zea.models.hvae.model import VAE
from zea.models.hvae.utils import Parameters
from zea.models.preset_utils import get_preset_loader, register_presets
from zea.models.presets import hvae_presets

SUPPORTED_VERSIONS = [
    "lvh",
    "lvh_ur24",
    "lvh_ur16",
    "lvh_ur8",
    "lvh_ur4",
    "lvh_ge24",
    "lvh_ge16",
    "lvh_ge8",
    "lvh_ge4",
]


[docs] @model_registry(name="hvae") class HierarchicalVAE(DeepGenerativeModel): """ Hierarchical Variational Autoencoder (HVAE) model. The network as defined here is a snippet of the complete model at: https://github.com/swpenninga/hvae The lvh versions are trained on EchoNetLVH at 256x256 resolution with 3 channels. (video-frames as channel dimension) The ur(.) versions denote retraining with a UniformRandom agent with (.)/256 lines. Unlike the other models, this network is built when the weights are loaded. """ def __init__(self, name="hvae", version="lvh", **kwargs): """ Args: name (str): Name of the model. version (str): Version of the HVAE model to use. Supported versions are: "lvh", "lvh_ur24", "lvh_ur16", "lvh_ur8", "lvh_ur4", "lvh_ge24", "lvh_ge16", "lvh_ge8", "lvh_ge4". """ super().__init__(name, **kwargs) assert version in SUPPORTED_VERSIONS, ( f"Unsupported version '{version}' for HVAE model." f"Current supported versions are: {', '.join(SUPPORTED_VERSIONS)}." ) self.version = version self.network = None
[docs] def custom_load_weights(self, preset): """ Load the pretrained weights of the HVAE model from a preset. First builds the model architecture from args.pkl, then loads the weights into the model. """ loader = get_preset_loader(preset) args_file = loader.get_file("args.pkl") weights_file = loader.get_file(f"hvae_{self.version}.weights.h5") # Build the model architecture from args.pkl with open(args_file, "rb") as f: args = pickle.load(f) params = Parameters(args) vae = VAE(params) vae.build() # Load and copy the weights vae.load_weights(weights_file) vae.trainable = False self.network = vae # Set model parameters that are used in partial_inference self.depth = params.model_depth self.stage_depth = params.dec_num_blocks self.z_out = params.z_out
[docs] def sample(self, n_samples=1, **kwargs): """ Samples from the prior distribution. Args: n_samples (int): Number of samples to generate. Returns: samples (tensor): Generated samples of shape (n_samples, 256, 256, 3) in [-1, 1]. """ logits = self.network.decoder.call_uncond(n_samples, **kwargs) # Returns a 100 channel mixture of logistic functions (logits). samples = self.network.sample_from_mol(logits) return samples
[docs] def posterior_sample(self, measurements, n_samples=1, **kwargs): """ Performs posterior sampling on a batch of measurements. Only does a single encoder pass since it is deterministic, but does n_samples decoder passes to create posterior samples. Args: measurements (tensor): Input measurements of shape [B, 256, 256, 3]. n_samples (int, optional): Number of posterior samples to generate. Defaults to 1. Returns: output (tensor): Posterior samples of shape [B, n_samples, 256, 256, 3]. """ # Measurements is [B, 256, 256, 3] in [-1, 1] b = ops.shape(measurements)[0] # Only need a single deterministic encoder pass activations = self.network.encoder(measurements) # Repeat the tensors in the list of activations n_samples amount of times # This repeats elementwise, so: [1, 2, 3] -> [1, 1, 2, 2, 3, 3] activations = [ops.repeat(a, repeats=n_samples, axis=0) for a in activations] # Logits are of shape [B * n_samples, 256, 256, 100] logits, _, _ = self.network.decoder.call(activations) # Samples are of shape [B * n_samples, 256, 256, 3] in [-1, 1] samples = self.network.sample_from_mol(logits) # Split the samples into [B, n_samples, 256, 256, 3] output = ops.stack(ops.split(samples, b, axis=0), axis=0) return output
[docs] def call(self, measurements): """ Returns a reconstruction of the input, together with the latent samples and KL divergences. Args: measurements (tensor): Input measurements of shape [B, 256, 256, 3]. Returns: recon (tensor): Reconstructed output of shape [B, 256, 256, 3], List of latent samples from the decoder, and list of KL divergences at each latent layer. """ # Returns reconstruction, latent samples, kl divergences recon, z_samples, kl = self.network.call(measurements) recon = self.network.sample_from_mol(recon) return recon, z_samples, kl
[docs] def partial_inference(self, measurements, num_layers=0.5, n_samples=1, **kwargs): """ Performs TopDown inference with the HVAE up until a certain layer, after which it continues in the decoder with multiple prior streams. Args: measurements (tensor): Input measurements of shape [B, 256, 256, 3]. num_layers (float or int): If float, fraction of total layers to use from the top. If int, number of layers to use from the top. n_samples (int): Number of posterior samples to generate. Returns: output (tensor): Posterior samples of shape [B, n_samples, 256, 256, 3]. """ # Make sure num_layers is a float between 0 and 1 or an integer between 1 and depth if isinstance(num_layers, float): assert 0.0 < num_layers <= 1.0, "num_layers as float must be in (0.0, 1.0]" num_layers = int(num_layers * self.depth) elif isinstance(num_layers, int): assert 1 <= num_layers <= self.depth, f"num_layers as int must be in [1, {self.depth}]" else: raise ValueError("num_layers must be either a float or an int.") b = ops.shape(measurements)[0] # Only need a single deterministic encoder pass activations = self.network.encoder(measurements) # Single pass through the top num_layers of the decoder # Adding the same latent to z_stage n_samples times x = ops.zeros_like(activations[-1]) z = ops.tile(ops.zeros([1, *self.z_out]), (b * n_samples, 1, 1, 1)) current_layer = 0 for dec_stage, act in zip(self.network.decoder.stages.layers, reversed(activations)): for dec_block in dec_stage.blocks.layers: if current_layer < num_layers: # Use posterior sampling for the first num_layers x, z_block, _ = dec_block.call(x, act) z += ops.repeat(z_block, repeats=n_samples, axis=0) else: # Use prior sampling for the remaining layers if current_layer == num_layers: # At the threshold, we duplicate the rest of the chain x = ops.repeat(x, repeats=n_samples, axis=0) x, z_block = dec_block.call_uncond(x) z += z_block current_layer += 1 x = dec_stage.pool(x) z /= ops.sqrt(self.depth) px_z = self.network.decoder.activation(self.network.decoder.z_to_features(z)) for out_block in self.network.decoder.output_blocks.layers: px_z = out_block(px_z) px_z = self.network.decoder.last_conv(px_z) px_z = self.network.sample_from_mol(px_z) return ops.stack(ops.split(px_z, b, axis=0), axis=0)
[docs] def log_density(self, measurements, **kwargs): """ Calculates the log density (ELBO) of the data under the model. Args: measurements (tensor): Input measurements of shape [B, 256, 256, 3]. Returns: -elbo (tensor): negative ELBO of the input measurements, averaged over the batch. """ recon, _, kl = self.network.call(measurements) # elbo is averaged over batch dimension elbo, _, _ = self.network.get_elbo(measurements, recon, kl, **kwargs) return -elbo
register_presets(hvae_presets, HierarchicalVAE) __all__ = ["HierarchicalVAE"]