Expectation Maximization algorithm


I. Expectation Maximization

The Expectation Maximization (EM) algorithm is widely used in latent-variable settings such as semi-supervised and weakly supervised deep learning. We need to understand the EM algorithm before working on these semi- and weakly supervised settings in deep learning (check this post). I present here an intuitive example (based on a short explanation in this notebook):

1. Setup

Suppose we have two sets of data points drawn from two different Gaussian distributions, but we don't know which points belong to which distribution. We will use the EM algorithm to cluster the data points into 2 clusters, corresponding to the two Gaussian components, and to find the parameters of these two models (the 2 components of the Gaussian Mixture Model).

%matplotlib inline
import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import random
from scipy.stats import multivariate_normal     # for generating pdf
np.random.seed(110) # for reproducible results
random.seed(1010) # for reproducible results

m1 = [1,1]      # consider a random mean and covariance value
m2 = [7,7]                                              
cov1 = [[3, 2], [2, 3]]                                      
cov2 = [[2, -1], [-1, 2]]
x = np.random.multivariate_normal(m1, cov1, size=(200,))  # Generating 200 samples for each mean and covariance
y = np.random.multivariate_normal(m2, cov2, size=(200,))
data = np.concatenate((x, y), axis=0)

def plotData(data):
    plt.figure(figsize=(10,10))         
    colors = cm.rainbow(np.linspace(0, 1, len(data)))
    for d, c in zip(data, colors):
        plt.scatter(d[:,0], d[:,1], color=c, marker='o') 
#     plt.scatter(data[:,0], data[:,1], marker='o')     
    plt.axis('equal')                                  
    plt.xlabel('X-Axis', fontsize=16)              
    plt.ylabel('Y-Axis', fontsize=16)                     
    plt.title('Ground Truth', fontsize=22)    
    plt.grid()            
    plt.show()
plotData([x,y])  

Data

2. Expectation Maximization algorithm

The Expectation Maximization algorithm is described as follows:

  • We make an initial (random) guess of the parameters of the 2-component Gaussian mixture model (means, covariance matrices, and mixture probabilities $p_1$ and $p_2$).
  • E-step: Using the current model parameters, we evaluate the likelihood of each point under each component (using the probability density function). From these we compute the weight (responsibility) of each component for each point (the update equations are written below).
  • M-step: Using the estimated weights, we update the parameters of the GMM to maximize the likelihood.
  • Repeat E-step and M-step until the parameters have converged, or a maximum number of iterations has been reached.
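
For the 2-component case used here, these updates take the standard EM form, where $\mathcal{N}(x \mid \mu, \Sigma)$ denotes the Gaussian density and $w_{ik}$ is the weight of component $k$ at point $x_i$:

$$
w_{i1} = \frac{p_1\,\mathcal{N}(x_i \mid \mu_1, \Sigma_1)}{p_1\,\mathcal{N}(x_i \mid \mu_1, \Sigma_1) + p_2\,\mathcal{N}(x_i \mid \mu_2, \Sigma_2)}, \qquad w_{i2} = 1 - w_{i1}
$$

$$
\mu_k = \frac{\sum_i w_{ik}\, x_i}{\sum_i w_{ik}}, \qquad
\Sigma_k = \frac{\sum_i w_{ik}\,(x_i - \mu_k)(x_i - \mu_k)^\top}{\sum_i w_{ik}}, \qquad
p_k = \frac{1}{N}\sum_i w_{ik}
$$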

We make an initial (random) guess of the GMM parameters.

m1 = random.choice(data)
m2 = random.choice(data)
cov1 = np.cov(np.transpose(data))
cov2 = np.cov(np.transpose(data))
p1 = 0.5
p2 = 1-p1
#Probability density function of 2 Gaussian distributions : Initial State
def plotPDF(m1,m2,cov1,cov2, data):
    x1 = np.linspace(-4,11,200)  
    x2 = np.linspace(-4,11,200)
    X, Y = np.meshgrid(x1,x2) 

    Z1 = multivariate_normal(m1, cov1)  
    Z2 = multivariate_normal(m2, cov2)

    pos = np.empty(X.shape + (2,))                # a new array of given shape and type, without initializing entries
    pos[:, :, 0] = X; pos[:, :, 1] = Y   
    print (pos.shape)
    plt.figure(figsize=(10,10))                                                          # creating the figure and assigning the size
    plt.scatter(data[:,0], data[:,1], marker='o')     
    plt.contour(X, Y, Z1.pdf(pos), colors="r" ,alpha = 0.5) 
    plt.contour(X, Y, Z2.pdf(pos), colors="b" ,alpha = 0.5) 
    plt.axis('equal')                                                                  # making both the axis equal
    plt.xlabel('X-Axis', fontsize=16)                                                  # X-Axis
    plt.ylabel('Y-Axis', fontsize=16)                                                  # Y-Axis
    plt.title('GMM : Initial State', fontsize=22)                                            # Title of the plot
    plt.grid()                                                                         # displaying gridlines
    plt.show()
    
plotPDF(m1,m2,cov1,cov2, data)        

Initial State

The function below computes, for each data point, the weight of each group from the likelihoods:

def groupWeight(group_likelihood, total_likelihood):
    """
    Compute the weight for each group at each data point.
    group_likelihood is a vector of the likelihood (w.r.t a group) for all points
    total_likelihood is a vector of the sum of the likelihood of each point (w.r.t two groups)
    """
    return group_likelihood / total_likelihood
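
As a quick sanity check (with made-up likelihood values, not data from this post), the weights of the two groups at each point sum to 1:

# Toy likelihood values for two points under each of the two groups (hypothetical numbers)
g1 = np.array([0.20, 0.50])
g2 = np.array([0.30, 0.10])
total = g1 + g2
print(groupWeight(g1, total) + groupWeight(g2, total))   # -> [1. 1.]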

The two functions below implement the Expectation step and the Maximization step. (They loop over the data points instead of using matrix operations, to keep the code easy to follow; a vectorized version of the M-step is sketched after them.)

def Estep(data, m1, m2, cov1, cov2, p1):
    """
    Expectation step: compute the weight (responsibility) of each component for each data point.
    m1, m2, cov1, cov2, p1 are the current parameters of the two-component GMM.
    The third return value, the average weight of component 1, is the updated mixture probability p1.
    """
    pt1 = multivariate_normal.pdf(data, mean=m1, cov=cov1)
    pt2 = multivariate_normal.pdf(data, mean=m2, cov=cov2)
    total = p1*pt1 + (1-p1)*pt2
    weight1 = groupWeight(p1*pt1, total)
    weight2 = groupWeight((1-p1)*pt2, total)
    return weight1, weight2, np.sum(weight1)/len(data)

def Mstep(data, weight1, weight2):
    """
    Maximization step: re-estimate the means and covariances of the two components,
    using the weights computed in the E-step.
    """
    num_mu1, din_mu1, num_mu2, din_mu2 = 0, 0, 0, 0

    for i in range(0, len(data)):
        num_mu1 += weight1[i] * data[i]
        din_mu1 += weight1[i]

        num_mu2 += weight2[i] * data[i]
        din_mu2 += weight2[i]

    mu1 = num_mu1/din_mu1
    mu2 = num_mu2/din_mu2

    num_s1, din_s1, num_s2, din_s2 = 0, 0, 0, 0
    for i in range(0, len(data)):

        q1 = data[i] - mu1
        num_s1 += weight1[i] * np.outer(q1, q1)
        din_s1 += weight1[i]

        q2 = data[i] - mu2
        num_s2 += weight2[i] * np.outer(q2, q2)
        din_s2 += weight2[i]

    s1 = num_s1/din_s1
    s2 = num_s2/din_s2

    return mu1, mu2, s1, s2
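
For reference, here is the same M-step written with vectorized NumPy operations instead of the explicit loop (a sketch; the name Mstep_vectorized is just for illustration and is not used in the rest of the post):

def Mstep_vectorized(data, weight1, weight2):
    """Same updates as Mstep, using NumPy broadcasting instead of a Python loop."""
    mu1 = weight1 @ data / weight1.sum()                  # weighted mean of component 1
    mu2 = weight2 @ data / weight2.sum()                  # weighted mean of component 2
    d1, d2 = data - mu1, data - mu2
    s1 = (weight1[:, None] * d1).T @ d1 / weight1.sum()   # weighted covariance of component 1
    s2 = (weight2[:, None] * d2).T @ d2 / weight2.sum()   # weighted covariance of component 2
    return mu1, mu2, s1, s2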

3. Run the EM algorithm and draw the animation

def animate(i): 
    """
    Repeat the E-step and M-step. We put the iteration process in this function because
    we need it to save the animation frame by frame.
    """
    global m1,m2,cov1,cov2,p1
    global cont1, cont2
    # E-step: compute the weights (and updated mixture probability) with the current parameters
    w1, w2, p1 = Estep(data, m1, m2, cov1, cov2, p1)
    # M-step: update the GMM parameters using the weights
    m1, m2, cov1, cov2 = Mstep(data, w1, w2)
    
    #the two Gaussian components
    Z1 = multivariate_normal(m1, cov1)
    Z2 = multivariate_normal(m2, cov2)

    # We need to clear the contours drawn in previous iterations
    if cont1 is not None:
        for c in cont1.collections:
            c.remove()  # removes only the contours, leaves the rest intact        
    if cont2 is not None:
        for c in cont2.collections:
            c.remove()  # removes only the contours, leaves the rest intact

    #Plot the GMM model 
    cont1 = plt.contour(X, Y, Z1.pdf(pos), colors="r" ,alpha = 0.5)
    cont2 = plt.contour(X, Y, Z2.pdf(pos), colors="b" ,alpha = 0.5) 
    plt.title(r't = %i' % i)
    return cont1, cont2
 

import matplotlib.animation as animation   
# Set up formatting for the movie files
Writer = animation.writers['ffmpeg']
writer = Writer(fps=15, metadata=dict(artist='Me'), bitrate=1800)

iterations = 60
fig = plt.figure(figsize=(10,10)) 
plt.axis('equal')             # making both the axis equal       
plt.scatter(data[:,0], data[:,1], marker='o')                                               
plt.xlabel('X-Axis', fontsize=16)                                                  # X-Axis
plt.ylabel('Y-Axis', fontsize=16)                                                  # Y-Axis
plt.grid()    
x1 = np.linspace(-4,11,200)  
x2 = np.linspace(-4,11,200)
X, Y = np.meshgrid(x1,x2) 
pos = np.empty(X.shape + (2,))                # a new array of given shape and type, without initializing entries
pos[:, :, 0] = X; pos[:, :, 1] = Y

cont1, cont2 = None, None
anim = animation.FuncAnimation(fig, animate, interval=300, frames=iterations)
anim.save('animation.mp4', writer=writer)   # use the ffmpeg writer configured above

Animation EM
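
If you only need the fitted parameters and not the video, the same E-step and M-step can be run in a plain loop (a minimal sketch reusing the functions above; the 60 iterations are arbitrary):

# Re-initialize the parameters as in section 2, then alternate E- and M-steps
m1, m2 = random.choice(data), random.choice(data)
cov1 = np.cov(np.transpose(data))
cov2 = np.cov(np.transpose(data))
p1 = 0.5
for _ in range(60):
    w1, w2, p1 = Estep(data, m1, m2, cov1, cov2, p1)
    m1, m2, cov1, cov2 = Mstep(data, w1, w2)
print("mean 1:", m1, "\nmean 2:", m2, "\nmixture probability p1:", p1)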