import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets.samples_generator import make_blobs
from sklearn.datasets.samples_generator import make_circles
K = 4 # 类别数目
MAX_ITERS = 1000 # 最大迭代次数
N = 200 # 样本点数目
centers = [[-2, -2], [-2, 1.5], [1.5, -2], [2, 1.5]] # 簇中心
# 生成人工数据集
#data, features = make_circles(n_samples=200, shuffle=True, noise=0.1, factor=0.4)
data, features = make_blobs(n_samples=N, centers=centers, n_features = 2, cluster_std=0.8, shuffle=False, random_state=42)
print(data)
print(features)
# 计算类内平均值函数
def clusterMean(data, id, num):
total = tf.unsorted_segment_sum(data, id, num) # 第一个参数是tensor,第二个参数是簇标签,第三个是簇数目
count = tf.unsorted_segment_sum(tf.ones_like(data), id, num)
return total/count
# 构建graph
points = tf.Variable(data)
cluster = tf.Variable(tf.zeros([N], dtype=tf.int64))
centers = tf.Variable(tf.slice(points.initialized_value(), [0, 0], [K, 2]))# 将原始数据前k个点当做初始中心
repCenters = tf.reshape(tf.tile(centers, [N, 1]), [N, K, 2]) # 复制操作,便于矩阵批量计算距离
repPoints = tf.reshape(tf.tile(points, [1, K]), [N, K, 2])
sumSqure = tf.reduce_sum(tf.square(repCenters-repPoints), reduction_indices=2) # 计算距离
bestCenter = tf.argmin(sumSqure, axis=1) # 寻找最近的簇中心
change = tf.reduce_any(tf.not_equal(bestCenter, cluster)) # 检测簇中心是否还在变化
means = clusterMean(points, bestCenter, K) # 计算簇内均值
# 将粗内均值变成新的簇中心,同时分类结果也要更新
with tf.control_dependencies([change]):
update = tf.group(centers.assign(means),cluster.assign(bestCenter)) # 复制函数
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
changed = True
iterNum = 0
while changed and iterNum < MAX_ITERS:
iterNum += 1
# 运行graph
[changed, _] = sess.run([change, update])
[centersArr, clusterArr] = sess.run([centers, cluster])
print(clusterArr)
print(centersArr)
# 显示图像
fig, ax = plt.subplots()
ax.scatter(data.transpose()[0], data.transpose()[1], marker='o', s=100, c=clusterArr)
plt.plot()
plt.show()
这里需要注意的地方有:
1、unsorted_segment_sum函数是用来分割求和的,第二个参数就是分割的index,index相同的作为一个整体求和。
2、计算距离的时候使用了矩阵的批量运算,因此看起来不太直观,稍微推导一下就明白了。
3、tf.control_dependencies用来控制op运行顺序,只有检测类中心还在变化,再完成之后的更新操作。
4、tf.group是封装多个操作的函数。
5、画图函数内置在了训练过程中,因此每一轮迭代的结果都有显示,这是个很小的demo,因此迭代几轮后就可以收敛了。