从0实现高斯混合模型(EM-GMM)

Problem:

  1. Please build a Gaussian mixture model (GMM) to model the data in file TrainingData_GMM.csv. Note that the data is composed of 4 clusters, and the model should be trained by expectation maximization (EM) algorithm.

  2. Based on the GMM learned above, assign each training data point into one of 4 different clusters

Questions:

1) Show how the log-likelihood evolves as the training proceeds


image

x轴为迭代次数,y轴为log-likelihood值

2) The learned mathematical expression for the GMM model after training on the given dataset

\alpha=\begin{bmatrix}0.23048224536932024\\0.22999999854996792\\0.272826418924052\\0.2666913371566595\end{bmatrix}

\mu = \begin{bmatrix}-0.40658&0.32248\\ 1.20354&-1.19686\\ 0.14435&0.14614\\ -0.44149&-0.45088\\\end{bmatrix}

\sigma = \begin{bmatrix} \begin{bmatrix}0.03446&-0.01299\\ -0.01299&0.03458\\\end{bmatrix} \begin{bmatrix}0.02259&-0.00761\\ -0.00761&0.02361\\\end{bmatrix} \begin{bmatrix}0.00886&0.00187\\ 0.00187&0.00881\\\end{bmatrix} \begin{bmatrix}0.07024&0.03731\\ 0.03731&0.06498\\\end{bmatrix} \end{bmatrix}
3) Randomly select 500 data points from the given dataset and plot them on a 2dimensional coordinate system. Mark the data points coming from the same cluster (using the results of Problem 2) with the same color.

image

4) Some analyses on the impacts of initialization on the converged values of EM algorithm
不同的初始参数对EM-GMM算法最后收敛的效果影响非常大,我的
\mu
是随机生成的,最佳收敛为-1946左右,但是也在-3000,-2000收敛过

5) Some analyses on the results you obtained
从结果上来看,算法已经可以很好地区分4个类别,说明EM-GMM算法在这种含有隐变量的问题,尤其是聚类问题中有着不错的效果。这次作业我也从推导入手,弄明白了EM与GMM,以及二维高斯分布等问题,收获颇丰!

import numpy as np
# Read data from csv
data = np.genfromtxt('TrainingData_GMM.csv', delimiter=',')
# n = num of sample, d = dimension of data, in this case, is 2-d.
n,d = np.shape(data)
# k = num of cluster, set to 4 by the problem describtion.
k = 4
# iter
iter = 500
# 2-d guassin distribution
def guassin_distribution(_x,_miu,_sigma):
    tmp1 = 1/((2*np.pi)*np.linalg.det(_sigma)**(1/2))
    tmp2 = np.exp(-0.5*(_x-_miu)@np.linalg.inv(_sigma)@(_x-_miu))
    return tmp1*tmp2
# calculate log-likelihood of given parameters
def log_likelihood():
    prob = np.zeros([n,k])
    for i in range(n):
        for j in range(k):
            prob[i, j] = guassin_distribution(data[i],miu[j],sigma[j]) 
    return np.sum(np.log(prob@alpha))
# E step, calculate gamma[y,j].
def E():
    gamma = np.zeros([n,k])
    for i in range(n):
        for j in range(k):
            gamma[i,j] = alpha[j] * guassin_distribution(data[i],miu[j],sigma[j]) 
#     print(np.sum(gamma,axis=1).reshape((n,1)))
    new_gamma = gamma / np.sum(gamma,axis=1).reshape((n,1))
    return new_gamma,gamma
# M step, calculate the latest parameters
def M(gamma):
    K_sum_gamma = np.sum(gamma,axis=0)
    miu_tmp = np.zeros((k, d))
    sigma_tmp = np.zeros((k, d, d))
    alpha_tmp = np.zeros(k)
    for j in range(k):
        tmp_miu = 0
        tmp_sigma = 0
        for i in range(n):
            tmp_miu += gamma[i,j]*data[i]
            tmp_sigma += gamma[i,j]*(data[i]-miu[j]).reshape(2,1)@(data[i]-miu[j]).reshape(1,2)
        miu_tmp[j] = tmp_miu/K_sum_gamma[j]
        sigma_tmp[j] = tmp_sigma/K_sum_gamma[j]
        alpha_tmp[j] = K_sum_gamma[j]/n
    return miu_tmp,sigma_tmp,alpha_tmp
# parmameter initialization
# miu,sigma is the parameter of gaussian distribution
# 因为这里是二维的高斯分布,所以sigma在这里是2维的协方差(Dim*Dim),当处理一维高斯分布的时候我们通常把sigma^2看作方差

miu = np.random.randint(0,2,(k,d))
sigma = np.random.randint(0,2,(k,d,d))
alpha = np.array([0.25,0.25,0.25,0.25])
print(miu)
for i in range(0, k):
    sigma[i] = np.diag([1,1])
    
index = 0
likehood_list = []
while 1 :
    index += 1
    gamma,_ = E()
    miu, sigma, alpha = M(gamma)
    loglike  = log_likelihood()
    likehood_list.append(loglike)
    if index > 5 and  abs(loglike - likehood_list[-2]) < 0.01 :
        print('Finish training')
        break
#     loglike = log_likelihood()
    print('index %d, log likehood %f'%(index,loglike))
[[1 1]
 [1 0]
 [0 0]
 [0 1]]
[ 611.45229966 1430.6772165  1926.06578173 1031.80470211]
index 1, log likehood -9672.165311
[ 335.75137577 1331.29076907 2399.18394139  933.77391378]
index 2, log likehood -7516.916844
[ 295.54773607 1315.88106212 2098.17872769 1290.39247412]
index 3, log likehood -6732.294851
[ 237.72130526 1330.64190269 1722.02088632 1709.61590574]
index 4, log likehood -5824.873937
[ 207.66570465 1342.17076547 1470.71920008 1979.44432979]
index 5, log likehood -5031.472039
[ 197.38485415 1269.03430146 1470.63230013 2062.94854426]
index 6, log likehood -4229.474246
[ 179.51063735 1174.58748844 1603.0608239  2042.8410503 ]
index 7, log likehood -3351.133138
[ 171.67195764 1150.05077496 1709.70143046 1968.57583694]
index 8, log likehood -2959.046982
[ 178.79795077 1149.99864314 1792.51929472 1878.68411137]
index 9, log likehood -2927.094710
[ 193.96805643 1149.99873754 1861.41354903 1794.619657  ]
index 10, log likehood -2905.027898
[ 213.54444589 1149.99946804 1917.28085395 1719.17523212]
index 11, log likehood -2886.123975
[ 236.35308699 1149.99983307 1960.1619301  1653.48514984]
index 12, log likehood -2868.333690
[ 262.01822615 1149.99992085 1989.96797803 1598.01387497]
index 13, log likehood -2850.162029
[ 290.40638865 1149.99994905 2006.96904628 1552.62461602]
index 14, log likehood -2831.098999
[ 320.9322868  1149.99996915 2012.72943982 1516.33830423]
index 15, log likehood -2812.377171
[ 352.0658956  1149.99998355 2010.51437783 1487.41974302]
index 16, log likehood -2795.924313
[ 381.96342042 1149.99999176 2003.92893738 1464.10765044]
index 17, log likehood -2782.622865
[ 409.29180942 1149.99999574 1995.47462626 1445.23356858]
index 18, log likehood -2772.198095
[ 433.42623915 1149.99999759 1986.46653381 1430.10722945]
index 19, log likehood -2763.908830
[ 454.32611636 1149.99999847 1977.46869471 1418.20519047]
index 20, log likehood -2756.982532
[ 472.3401729  1149.99999891 1968.60577922 1409.05404897]
index 21, log likehood -2750.755481
[ 488.03176146 1149.99999913 1959.72715051 1402.2410889 ]
index 22, log likehood -2744.664587
[ 502.07146729 1149.99999924 1950.48262339 1397.44591008]
index 23, log likehood -2738.171491
[ 515.21134548 1149.99999928 1940.33514025 1394.453515  ]
index 24, log likehood -2730.646916
[ 528.34343007 1149.99999927 1928.51915412 1393.13741655]
index 25, log likehood -2721.213850
[ 542.65002974 1149.99999921 1913.95371358 1393.39625747]
index 26, log likehood -2708.543565
[ 559.8341125  1149.9999991  1895.14228341 1395.02360499]
index 27, log likehood -2690.605833
[ 582.32328095 1149.99999894 1870.14857631 1397.52814381]
index 28, log likehood -2664.390693
[ 613.17479658 1149.99999871 1836.83467602 1399.99052869]
index 29, log likehood -2625.841179
[ 655.23092526 1149.99999839 1793.67419024 1401.09488611]
index 30, log likehood -2570.724441
[ 709.4085606  1149.99999793 1741.09910613 1399.49233533]
index 31, log likehood -2497.122214
[ 773.45862031 1149.99999729 1682.02719971 1394.51418269]
index 32, log likehood -2406.763618
[ 842.49057709 1149.99999645 1620.80983062 1386.69959584]
index 33, log likehood -2301.998774
[ 910.83901083 1149.9999955  1561.73542083 1377.42557284]
index 34, log likehood -2188.814428
[ 973.19529188 1149.99999457 1508.32843011 1368.47628345]
index 35, log likehood -2082.601753
[1024.65287295 1149.99999384 1463.98573368 1361.36139953]
index 36, log likehood -2006.034687
[1062.59718728 1149.99999346 1430.74073067 1356.66208859]
index 37, log likehood -1968.437544
[1088.75814826 1149.99999334 1407.47964219 1353.76221621]
index 38, log likehood -1954.693990
[1106.59286891 1149.99999327 1391.88124644 1351.52589138]
index 39, log likehood -1949.858506
[1118.97485693 1149.99999317 1381.70521529 1349.31993461]
index 40, log likehood -1947.959181
[1127.78707665 1149.99999306 1375.16996326 1347.04296702]
index 41, log likehood -1947.122789
[1134.20779982 1149.99999298 1371.00469136 1344.78751583]
index 42, log likehood -1946.719193
[1138.98103843 1149.99999292 1368.36206003 1342.65690863]
index 43, log likehood -1946.509007
[1142.589615   1149.99999287 1366.69310582 1340.71728632]
index 44, log likehood -1946.392227
[1145.35625181 1149.99999283 1365.64552897 1338.99822638]
index 45, log likehood -1946.323806
[1147.50243001 1149.99999281 1364.99377606 1337.50380112]
index 46, log likehood -1946.282033
[1149.18368366 1149.99999279 1364.59348535 1336.22283821]
index 47, log likehood -1946.255744
[1150.51148373 1149.99999277 1364.35227103 1335.13625247]
index 48, log likehood -1946.238844
[1151.56720095 1149.99999276 1364.21104918 1334.22175711]
index 49, log likehood -1946.227821
[1152.41122685 1149.99999275 1364.13209462 1333.45668578]
Finish training
import pylab
# %matplotlib inline 
%config InlineBackend.figure_format = 'png'
pylab.style.use('default')
pylab.plot(likehood_list)
[<matplotlib.lines.Line2D at 0x122913908>]
image
node_num = 500
_,gamma=E()
label = np.argmax(gamma,1)
selected_node_index = np.random.choice(range(n),size=node_num)
node_pos = data[selected_node_index]
label = label[selected_node_index]
pylab.scatter(node_pos[:,0],node_pos[:,1],marker='o',c=label,cmap=pylab.cm.Accent)
<matplotlib.collections.PathCollection at 0x1212d0b00>
image
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,362评论 5 477
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,330评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,247评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,560评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,580评论 5 365
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,569评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,929评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,587评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,840评论 1 297
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,596评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,678评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,366评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,945评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,929评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,165评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 43,271评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,403评论 2 342

推荐阅读更多精彩内容