Variational Autoencoder

Authors

Title: Variational Autoencoder

Author: Antonio Lorenzo

Subject: Machine learning

Language: English

Summary

In this project, we explore the construction and application of Variational Autoencoders (VAEs) to generate content using deep learning techniques. We implement the VAE model using Keras and TensorFlow, leveraging the Fashion MNIST dataset for training. The model consists of an encoder that maps the input data to a latent space and a decoder that reconstructs the data from the latent space, allowing us to generate new data samples by sampling from this space. VAEs are commonly used in generative modeling, where the goal is to create realistic data samples similar to the original dataset.

Data Loading and Preprocessing:

We start by loading the Fashion MNIST dataset, which consists of 28x28 grayscale images of clothing items. Pixel values are scaled from the range 0 to 255 down to the range 0 to 1, which helps training and matches the sigmoid output and binary cross-entropy loss used later.

# Use the public compat.v1 path and a single Keras namespace (tensorflow.keras)
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow import keras
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Input, Flatten, Dense, Lambda, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras import backend as K
import numpy as np
import matplotlib.pyplot as plt

# Load Fashion MNIST
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

Normalize and reshape

# Normalize pixel values from [0, 255] to [0, 1]
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train = x_train / 255
x_test = x_test / 255

# Reshape to (samples, height, width, channels)
img_height = x_train.shape[1]
img_width = x_train.shape[2]
num_channels = 1  # Fashion MNIST is grayscale, so a single channel
x_train = x_train.reshape(x_train.shape[0], img_height, img_width, num_channels)
x_test = x_test.reshape(x_test.shape[0], img_height, img_width, num_channels)
input_shape = (img_height, img_width, num_channels)

Visualizing the Data:

We visualize a few samples from the dataset to gain an understanding of what kind of images we are working with.

plt.figure(1)
plt.subplot(221)
plt.imshow(x_train[42][:,:,0])

plt.subplot(222)
plt.imshow(x_train[420][:,:,0])

plt.subplot(223)
plt.imshow(x_train[4200][:,:,0])

plt.subplot(224)
plt.imshow(x_train[42000][:,:,0])
plt.show()

Build the model

Building the Encoder:

The encoder part of the VAE is built from several convolutional layers, which extract hierarchical features from the input images. Its output is parameterized by two vectors, z_mu and z_sigma, which represent the mean and the log-variance of the latent distribution. Despite its name, z_sigma is treated as a log-variance throughout the code: the sampling function uses exp(z_sigma / 2) as the standard deviation.

# Define 4 Conv2D layers, then Flatten, then Dense

latent_dim = 2 # Number of latent dim parameters

input_img = Input(shape=input_shape, name='encoder_input')
x = Conv2D(32, 3, padding='same', activation='relu')(input_img)
x = Conv2D(64, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2D(64, 3, padding='same', activation='relu')(x)
x = Conv2D(64, 3, padding='same', activation='relu')(x)

conv_shape = K.int_shape(x) #Shape of conv to be provided to decoder
#Flatten
x = Flatten()(x)
x = Dense(32, activation='relu')(x)

# Two outputs: the latent mean and the latent log-variance.
# These are used to sample the latent vectors that inputs are mapped to.
z_mu = Dense(latent_dim, name='latent_mu')(x)   # Mean of the encoded input
z_sigma = Dense(latent_dim, name='latent_sigma')(x)  # Log-variance of the encoded input

# REPARAMETERIZATION TRICK
# Define a sampling function to draw z from the latent distribution.
# Reparameterize the sample, following the process described by Gunderson and Huang,
# as z = mu + exp(log_var / 2) * eps = mu + sigma * eps, with eps ~ N(0, I).
# This keeps the sampling step differentiable with respect to mu and log_var,
# so gradients can be propagated through it during training.
def sample_z(args):
  z_mu, z_sigma = args
  eps = K.random_normal(shape=(K.shape(z_mu)[0], K.int_shape(z_mu)[1]))
  return z_mu + K.exp(z_sigma / 2) * eps

# Sample a vector from the latent distribution.
# z is a custom Lambda layer that applies sample_z to the mean (z_mu)
# and log-variance (z_sigma).
z = Lambda(sample_z, output_shape=(latent_dim, ), name='z')([z_mu, z_sigma])

#Z (lambda layer) will be the last layer in the encoder.
# Define and summarize encoder model.
encoder = Model(input_img, [z_mu, z_sigma, z], name='encoder')
print(encoder.summary())

Model: "encoder"
    __________________________________________________________________________________________________
     Layer (type)                Output Shape                 Param #   Connected to                  
    ==================================================================================================
     encoder_input (InputLayer)  [(None, 28, 28, 1)]          0         []                            
                                                                                                      
     conv2d_4 (Conv2D)           (None, 28, 28, 32)           320       ['encoder_input[0][0]']       
                                                                                                      
     conv2d_5 (Conv2D)           (None, 14, 14, 64)           18496     ['conv2d_4[0][0]']            
                                                                                                      
     conv2d_6 (Conv2D)           (None, 14, 14, 64)           36928     ['conv2d_5[0][0]']            
                                                                                                      
     conv2d_7 (Conv2D)           (None, 14, 14, 64)           36928     ['conv2d_6[0][0]']            
                                                                                                      
     flatten_1 (Flatten)         (None, 12544)                0         ['conv2d_7[0][0]']            
                                                                                                      
     dense_2 (Dense)             (None, 32)                   401440    ['flatten_1[0][0]']           
                                                                                                      
     latent_mu (Dense)           (None, 2)                    66        ['dense_2[0][0]']             
                                                                                                      
     latent_sigma (Dense)        (None, 2)                    66        ['dense_2[0][0]']             
                                                                                                      
     z (Lambda)                  (None, 2)                    0         ['latent_mu[0][0]',           
                                                                         'latent_sigma[0][0]']        
                                                                                                      
    ==================================================================================================
    Total params: 494244 (1.89 MB)
    Trainable params: 494244 (1.89 MB)
    Non-trainable params: 0 (0.00 Byte)
    __________________________________________________________________________________________________
    None

The reparameterization trick allows us to backpropagate through the random sampling operation: the randomness is moved into a noise term drawn from a standard normal distribution, which is then scaled by the latent standard deviation and shifted by the latent mean.
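
As a rough illustration of the sampling step, the NumPy sketch below mirrors the sample_z function outside of Keras. It assumes, as the exp(z_sigma / 2) term in sample_z implies, that the second encoder output is a log-variance. The function name sample_z_numpy and the example values are made up for this illustration.

```python
import numpy as np

def sample_z_numpy(z_mu, z_log_var, rng):
    """Draw z = mu + exp(log_var / 2) * eps, with eps ~ N(0, I)."""
    eps = rng.standard_normal(size=z_mu.shape)
    return z_mu + np.exp(z_log_var / 2.0) * eps

rng = np.random.default_rng(0)

# Hypothetical encoder outputs for one input, repeated many times so that
# the sample statistics can be compared with the requested mean and std.
num_samples = 100_000
z_mu = np.tile([1.0, -2.0], (num_samples, 1))
z_log_var = np.tile(np.log([0.25, 4.0]), (num_samples, 1))  # variances 0.25 and 4.0

z = sample_z_numpy(z_mu, z_log_var, rng)
print(z.mean(axis=0))  # close to the requested means 1.0 and -2.0
print(z.std(axis=0))   # close to the requested std devs 0.5 and 2.0
```

The noise eps carries all the randomness, so z is a deterministic, differentiable function of the mean and log-variance once eps is drawn.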

Building the Decoder

The decoder maps the latent vector back to the original image space. It uses convolutional transpose layers to "upscale" the latent space representation back into an image.

# decoder takes the latent vector as input
decoder_input = Input(shape=(latent_dim, ), name='decoder_input')

# Start from a shape that can be mapped back to the original image shape,
# as we want the final output to have the same shape as the original input.
# So, add a dense layer whose output can be reshaped to the desired shape.
x = Dense(conv_shape[1]*conv_shape[2]*conv_shape[3], activation='relu')(decoder_input)
# Reshape to the shape of the last conv layer in the encoder, so we can
# upscale (Conv2DTranspose) back to the original shape.
x = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(x)
# Use Conv2DTranspose to reverse the strided conv layer defined in the encoder
x = Conv2DTranspose(32, 3, padding='same', activation='relu',strides=(2, 2))(x)
# More Conv2DTranspose layers can be added, if desired.
# Sigmoid activation keeps the output pixels in [0, 1]
x = Conv2DTranspose(num_channels, 3, padding='same', activation='sigmoid', name='decoder_output')(x)

# Define and summarize decoder model
decoder = Model(decoder_input, x, name='decoder')
decoder.summary()

# apply the decoder to the latent sample
z_decoded = decoder(z)
 Model: "decoder"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #   
    =================================================================
     decoder_input (InputLayer)  [(None, 2)]               0         
                                                                     
     dense_3 (Dense)             (None, 12544)             37632     
                                                                     
     reshape_1 (Reshape)         (None, 14, 14, 64)        0         
                                                                     
     conv2d_transpose_1 (Conv2D  (None, 28, 28, 32)        18464     
     Transpose)                                                      
                                                                     
     decoder_output (Conv2DTran  (None, 28, 28, 1)         289       
     spose)                                                          
                                                                     
    =================================================================
    Total params: 56385 (220.25 KB)
    Trainable params: 56385 (220.25 KB)
    Non-trainable params: 0 (0.00 Byte)
    _________________________________________________________________
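
The parameter counts reported in the model summaries can be checked by hand. The sketch below recomputes them from the layer definitions, assuming the usual counting rule of one weight per kernel element, input channel and output channel, plus one bias per output channel. The helper names conv_params and dense_params are made up for this illustration.

```python
def conv_params(kernel, in_channels, out_channels):
    # One kernel x kernel x in_channels filter per output channel, plus a bias each.
    # The same count applies to Conv2D and Conv2DTranspose layers.
    return kernel * kernel * in_channels * out_channels + out_channels

def dense_params(in_units, out_units):
    # Full weight matrix plus one bias per output unit.
    return in_units * out_units + out_units

flat = 14 * 14 * 64  # size of the flattened encoder feature map (12544)

encoder_total = (
    conv_params(3, 1, 32)         # first Conv2D
    + conv_params(3, 32, 64)      # strided Conv2D
    + 2 * conv_params(3, 64, 64)  # two further Conv2D layers
    + dense_params(flat, 32)      # Dense(32)
    + 2 * dense_params(32, 2)     # latent_mu and latent_sigma
)
decoder_total = (
    dense_params(2, flat)         # Dense from the 2-D latent vector
    + conv_params(3, 64, 32)      # Conv2DTranspose(32)
    + conv_params(3, 32, 1)       # decoder_output
)
print(encoder_total, decoder_total, encoder_total + decoder_total)
```

The three numbers should agree with the totals in the encoder, decoder and VAE summaries (494244, 56385 and 550629).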

Define a custom loss

The VAE is trained with a loss made up of two terms:

  • Reconstruction Loss: Measures how well the decoder reconstructs the input image.
  • KL Divergence: Ensures that the learned latent distribution is close to a standard normal distribution.
# Custom layer that computes the combined VAE loss and adds it to the model
class CustomLayer(keras.layers.Layer):

    def vae_loss(self, x, z_decoded):
        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)

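        # Reconstruction loss (the decoder ends in a sigmoid, so binary cross-entropy applies)
        recon_loss = keras.losses.binary_crossentropy(x, z_decoded)

        # KL divergence between N(z_mu, exp(z_sigma)) and the standard normal;
        # z_sigma holds the log-variance, and 5e-4 down-weights this term.
        kl_loss = -5e-4 * K.mean(1 + z_sigma - K.square(z_mu) - K.exp(z_sigma), axis=-1)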
        return K.mean(recon_loss + kl_loss)

    # add custom loss to the class
    def call(self, inputs):
        x = inputs[0]
        z_decoded = inputs[1]
        loss = self.vae_loss(x, z_decoded)
        self.add_loss(loss, inputs=inputs)
        return x

# Apply the custom layer to the input images and the decoded latent sample
y = CustomLayer()([input_img, z_decoded])
# y is the input image passed through unchanged; the layer exists to register
# the VAE loss, computed from the input and the decoded sample of z.
# y will be used as the output of the VAE model.
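
To make the two loss terms more concrete, here is a minimal NumPy sketch of each. It uses the textbook closed form of the KL divergence between a diagonal Gaussian and the standard normal, with a factor of 0.5 and a sum over latent dimensions; the Keras code above uses the same expression inside the brackets but scales it by 5e-4 and averages instead. The function names are made up for this illustration.

```python
import numpy as np

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, exp(log_var)) || N(0, I)), summed over latent dims."""
    return 0.5 * np.sum(np.square(mu) + np.exp(log_var) - log_var - 1.0, axis=-1)

def binary_crossentropy(x, x_hat, eps=1e-7):
    """Mean per-pixel binary cross-entropy between targets x and predictions x_hat."""
    x_hat = np.clip(x_hat, eps, 1.0 - eps)  # avoid log(0)
    return -np.mean(x * np.log(x_hat) + (1.0 - x) * np.log(1.0 - x_hat))

# Two hypothetical latent codes with unit variance (log_var = 0):
# one centred at the origin, one shifted away from it.
mu = np.array([[0.0, 0.0], [1.0, -1.0]])
log_var = np.zeros((2, 2))
kl = kl_to_standard_normal(mu, log_var)
print(kl)  # zero for the first code, positive for the shifted one

# A maximally uncertain prediction (0.5 everywhere) costs log(2) per pixel.
bce = binary_crossentropy(np.array([0.0, 1.0]), np.array([0.5, 0.5]))
print(bce)
```

The KL term is zero only when the encoded distribution is exactly the standard normal, which is why it pulls the latent codes toward the origin while the reconstruction term pulls them apart.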

Variational Autoencoder Model (VAE)

Finally, we combine the encoder and decoder into a single VAE model. Because the custom layer already registers the loss through add_loss, the model is compiled with loss=None.

vae = Model(input_img, y, name='vae')

# Compile VAE
vae.compile(optimizer='adam', loss=None)
vae.summary()
Model: "vae"
    __________________________________________________________________________________________________
     Layer (type)                Output Shape                 Param #   Connected to                  
    ==================================================================================================
     encoder_input (InputLayer)  [(None, 28, 28, 1)]          0         []                            
                                                                                                      
     conv2d_4 (Conv2D)           (None, 28, 28, 32)           320       ['encoder_input[0][0]']       
                                                                                                      
     conv2d_5 (Conv2D)           (None, 14, 14, 64)           18496     ['conv2d_4[0][0]']            
                                                                                                      
     conv2d_6 (Conv2D)           (None, 14, 14, 64)           36928     ['conv2d_5[0][0]']            
                                                                                                      
     conv2d_7 (Conv2D)           (None, 14, 14, 64)           36928     ['conv2d_6[0][0]']            
                                                                                                      
     flatten_1 (Flatten)         (None, 12544)                0         ['conv2d_7[0][0]']            
                                                                                                      
     dense_2 (Dense)             (None, 32)                   401440    ['flatten_1[0][0]']           
                                                                                                      
     latent_mu (Dense)           (None, 2)                    66        ['dense_2[0][0]']             
                                                                                                      
     latent_sigma (Dense)        (None, 2)                    66        ['dense_2[0][0]']             
                                                                                                      
     z (Lambda)                  (None, 2)                    0         ['latent_mu[0][0]',           
                                                                         'latent_sigma[0][0]']        
                                                                                                      
     decoder (Functional)        (None, 28, 28, 1)            56385     ['z[0][0]']                   
                                                                                                      
     custom_layer_1 (CustomLaye  (None, 28, 28, 1)            0         ['encoder_input[0][0]',       
     r)                                                                  'decoder[0][0]']             
                                                                                                      
    ==================================================================================================
    Total params: 550629 (2.10 MB)
    Trainable params: 550629 (2.10 MB)
    Non-trainable params: 0 (0.00 Byte)
    __________________________________________________________________________________________________

Training the VAE

We train the VAE for 10 epochs with a batch size of 32, holding out 20% of the training data for validation, and monitor the loss during training.

# Train autoencoder
vae.fit(x_train, None, epochs = 10, batch_size = 32, validation_split = 0.2)
Train on 48000 samples, validate on 12000 samples
    Epoch 1/10
    48000/48000 [==============================] - 215s 4ms/sample - loss: 0.3641 - val_loss: 0.3439
    Epoch 2/10
    48000/48000 [==============================] - 214s 4ms/sample - loss: 0.3400 - val_loss: 0.3400
    Epoch 3/10
    48000/48000 [==============================] - 212s 4ms/sample - loss: 0.3364 - val_loss: 0.3386
    Epoch 4/10
    48000/48000 [==============================] - 213s 4ms/sample - loss: 0.3346 - val_loss: 0.3351
    Epoch 5/10
    48000/48000 [==============================] - 213s 4ms/sample - loss: 0.3331 - val_loss: 0.3339
    Epoch 6/10
    48000/48000 [==============================] - 213s 4ms/sample - loss: 0.3320 - val_loss: 0.3350
    Epoch 7/10
    48000/48000 [==============================] - 213s 4ms/sample - loss: 0.3310 - val_loss: 0.3328
    Epoch 8/10
    48000/48000 [==============================] - 214s 4ms/sample - loss: 0.3302 - val_loss: 0.3328
    Epoch 9/10
    48000/48000 [==============================] - 214s 4ms/sample - loss: 0.3293 - val_loss: 0.3312
    Epoch 10/10
    48000/48000 [==============================] - 214s 4ms/sample - loss: 0.3292 - val_loss: 0.3314
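
One simple way to read this log is to look at where the validation loss stops improving. The sketch below does that with the val_loss values copied from the output above; it is only a way of inspecting the numbers, not part of the original training code.

```python
# Validation losses copied from the training log above (epochs 1 to 10).
val_loss = [0.3439, 0.3400, 0.3386, 0.3351, 0.3339,
            0.3350, 0.3328, 0.3328, 0.3312, 0.3314]

# Epoch (1-based) with the lowest validation loss.
best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__) + 1

# Improvement from each epoch to the next; negative values mean the loss went up.
improvements = [round(a - b, 4) for a, b in zip(val_loss, val_loss[1:])]

print(best_epoch, val_loss[best_epoch - 1])
print(improvements)
```

The gains shrink to a few ten-thousandths by the last epochs, which suggests that training much longer with this setup would change the validation loss only slightly.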

Visualize

# Visualize inputs mapped to the latent space.
# Recall that the inputs are encoded to a latent space of dimension 2.
# Extract z_mu, the first output of the encoder, which represents the mean.

mu, _, _ = encoder.predict(x_test)
#Plot dim1 and dim2 for mu
plt.figure(figsize=(10, 10))
plt.scatter(mu[:, 0], mu[:, 1], c=y_test, cmap='brg')
plt.xlabel('dim 1')
plt.ylabel('dim 2')
plt.colorbar()
plt.show()
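
The scatter plot colors points by numeric label, so it can help to summarize where each class sits in the latent space. The sketch below computes per-class centroids and pairs them with the standard Fashion MNIST class names. To stay self-contained it runs on synthetic stand-in data; with the real model, mu would come from encoder.predict(x_test) and labels from y_test. The helper class_centroids is made up for this illustration.

```python
import numpy as np

# Standard Fashion MNIST label names, indexed by class id 0 to 9.
CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

def class_centroids(mu, labels, num_classes=10):
    """Mean latent position of each class; mu has shape (N, latent_dim)."""
    return np.stack([mu[labels == c].mean(axis=0) for c in range(num_classes)])

# Synthetic stand-in data: 50 points per class, scattered around (c, c).
rng = np.random.default_rng(0)
labels = np.repeat(np.arange(10), 50)
mu = rng.normal(size=(labels.size, 2)) + labels[:, None]

centroids = class_centroids(mu, labels)
for name, (d1, d2) in zip(CLASS_NAMES, centroids):
    print(f"{name:12s} dim1={d1:+.2f} dim2={d2:+.2f}")
```

Centroids that lie close together point to classes the two-dimensional latent space struggles to separate.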
# Visualize images
# Decode a single image from a hand-picked latent vector (of size 1x2).
# The latent space range is about -5 to 5, so pick values within this range.
# Try starting with [-1, 1] and slowly moving to [-1.5, 1.5] to see how the
# output morphs from one image to another.
sample_vector = np.array([[-4,4]])
decoded_example = decoder.predict(sample_vector)
decoded_example_reshaped = decoded_example.reshape(img_height, img_width)
plt.imshow(decoded_example_reshaped)
plt.show()

# Automate this process by generating multiple images and plotting them.
# Use the decoder to generate images by varying the latent variables.
# Create a grid of the defined size, filled with zeros.
# Take samples from a linear space, in this example the range [-5, 5].
# Feed each sample to the decoder and write the output into the figure.


n = 20  # generate a 20x20 grid of images
figure = np.zeros((img_width * n, img_height * n, num_channels))

# Create a grid of latent variables, to be provided as inputs to decoder.predict.
# The range -5 to 5 roughly covers the latent space seen in the scatter plot.
grid_x = np.linspace(-5, 5, n)
grid_y = np.linspace(-5, 5, n)[::-1]

# Decode the latent vector for each square in the grid
for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
        z_sample = np.array([[xi, yi]])
        x_decoded = decoder.predict(z_sample)
        digit = x_decoded[0].reshape(img_width, img_height, num_channels)
        figure[i * img_width: (i + 1) * img_width,
               j * img_height: (j + 1) * img_height] = digit

plt.figure(figsize=(10, 10))
#Reshape for visualization
fig_shape = np.shape(figure)
figure = figure.reshape((fig_shape[0], fig_shape[1]))

plt.imshow(figure, cmap='gnuplot2')
plt.show()
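
The loop above calls decoder.predict once per grid cell, 400 times in total. A possible alternative is to build all latent points at once, decode them in a single batch, and tile the results. The sketch below shows the indexing with a stand-in decoder so that it runs without the trained model; latent_grid, tile_images and stand_in_decoder are names made up for this illustration, and with the real model the stand-in call would be replaced by decoder.predict(z_grid).

```python
import numpy as np

def latent_grid(n, low=-5.0, high=5.0):
    """Return an (n*n, 2) array of latent points in row-major order, top row = highest y."""
    grid_x = np.linspace(low, high, n)
    grid_y = np.linspace(low, high, n)[::-1]
    xx, yy = np.meshgrid(grid_x, grid_y)  # xx[i, j] = grid_x[j], yy[i, j] = grid_y[i]
    return np.stack([xx.ravel(), yy.ravel()], axis=1)

def tile_images(images, n):
    """Arrange (n*n, h, w) images into a single (n*h, n*w) mosaic."""
    h, w = images.shape[1:3]
    return images.reshape(n, n, h, w).transpose(0, 2, 1, 3).reshape(n * h, n * w)

def stand_in_decoder(z):
    # Hypothetical replacement for decoder.predict: returns constant 28x28x1
    # images whose value equals the first latent coordinate.
    return np.ones((len(z), 28, 28, 1)) * z[:, :1, None, None]

n = 20
z_grid = latent_grid(n)
decoded = stand_in_decoder(z_grid)        # real model: decoder.predict(z_grid)
mosaic = tile_images(decoded[..., 0], n)  # drop the channel axis before tiling
print(z_grid.shape, mosaic.shape)
```

The mosaic has the same layout as the figure built by the loop, so it can be passed to plt.imshow in the same way.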

Conclusion

In this project, we implemented a Variational Autoencoder that can generate new samples from the latent space of the Fashion MNIST dataset. The use of a variational approach allows for continuous exploration of the latent space, offering flexibility in generating diverse and novel data points. VAEs have significant potential in various fields such as image generation, anomaly detection, and data compression.