Source code for zea.models.layers
"""Layers used in zea.models"""
import math
import keras
from keras import layers, ops
[docs]
@keras.saving.register_keras_serializable()
def sinusoidal_embedding(x, embedding_min_frequency, embedding_max_frequency, embedding_dims):
    """Sinusoidal embedding layer.

    Evaluates sines and cosines of ``x`` at ``embedding_dims // 2``
    frequencies that are evenly spaced in log space between
    ``embedding_min_frequency`` and ``embedding_max_frequency``.

    Args:
        x: Input tensor. It is multiplied by a 1-D vector of angular speeds,
            so its last axis is presumably of size 1 (or broadcast-compatible
            with ``embedding_dims // 2``). TODO: confirm against callers.
        embedding_min_frequency: Lowest frequency (first linspace endpoint,
            in log space).
        embedding_max_frequency: Highest frequency (last linspace endpoint,
            in log space).
        embedding_dims: Requested embedding size. Each of the sin/cos halves
            uses ``embedding_dims // 2`` frequencies, so an odd value yields
            ``embedding_dims - 1`` output features.

    Returns:
        The sine terms followed by the cosine terms, concatenated along the
        last axis.
    """
    log_low = ops.log(embedding_min_frequency)
    log_high = ops.log(embedding_max_frequency)
    num_frequencies = embedding_dims // 2
    freqs = ops.exp(ops.linspace(log_low, log_high, num_frequencies))

    # Convert frequencies to angular speeds (2*pi*f). Cast to x's dtype before
    # the multiply with x below.
    omega = ops.cast(2.0 * math.pi * freqs, x.dtype)
    phase = omega * x
    return ops.concatenate([ops.sin(phase), ops.cos(phase)], axis=-1)
[docs]
def ResidualBlock(width):
    """Residual block with swish activation.

    Args:
        width: Number of filters used by every convolution in the block.

    Returns:
        A callable ``apply(x)`` that computes
        ``Conv3x3(Conv3x3_swish(BatchNorm(x))) + shortcut``. The shortcut is
        ``x`` unchanged when its channel count already equals ``width``;
        otherwise it is a 1x1 convolution of ``x``. BatchNorm is created
        with ``center=False, scale=False`` and is applied only on the main
        path; the shortcut uses the un-normalized input.
    """

    def apply(x):
        # The channel count is read from axis 3, so a rank-4 channels-last
        # input is presumably expected. TODO: confirm.
        if ops.shape(x)[3] != width:
            # Project the shortcut so it can be added to the main path.
            # Created before the main-path layers, as in the original ordering.
            shortcut = layers.Conv2D(width, kernel_size=1)(x)
        else:
            shortcut = x

        h = layers.BatchNormalization(center=False, scale=False)(x)
        h = layers.Conv2D(width, kernel_size=3, padding="same", activation="swish")(h)
        h = layers.Conv2D(width, kernel_size=3, padding="same")(h)
        return layers.Add()([h, shortcut])

    return apply
[docs]
def DownBlock(width, block_depth):
    """Downsampling block with residual connections.

    Args:
        width: Filter count passed to each ``ResidualBlock``.
        block_depth: Number of ``ResidualBlock`` applications before pooling.

    Returns:
        A callable ``apply(x)`` where ``x`` is a ``(tensor, skips)`` pair and
        ``skips`` is a list. It runs ``block_depth`` residual blocks and
        appends each block's output to ``skips`` in place. It then returns the
        result average-pooled with ``pool_size=2``. The list itself is not
        returned.
    """

    def apply(x):
        features, skip_list = x
        for _ in range(block_depth):
            features = ResidualBlock(width)(features)
            # Deliberately mutates the caller's list; UpBlock later pops from
            # the same list.
            skip_list.append(features)
        return layers.AveragePooling2D(pool_size=2)(features)

    return apply
[docs]
def UpBlock(width, block_depth):
    """Upsampling block with residual connections.

    Args:
        width: Filter count passed to each ``ResidualBlock``.
        block_depth: Number of skip-merge + ``ResidualBlock`` steps.

    Returns:
        A callable ``apply(x)`` where ``x`` is a ``(tensor, skips)`` pair and
        ``skips`` is a list. It upsamples the tensor by 2 (bilinear). Then,
        ``block_depth`` times, it concatenates the tensor with
        ``skips.pop()`` (the last element, removed in place) and applies a
        ``ResidualBlock``. The resulting tensor is returned. ``list.pop``
        raises ``IndexError`` if ``skips`` runs out.
    """

    def apply(x):
        features, skip_list = x
        features = layers.UpSampling2D(size=2, interpolation="bilinear")(features)
        for _ in range(block_depth):
            # LIFO: the most recently appended skip is merged first.
            merged = layers.Concatenate()([features, skip_list.pop()])
            features = ResidualBlock(width)(merged)
        return features

    return apply