二分Kmeans的java实现,二分kmeansjava


刚刚研究了Kmeans。Kmeans是一种十分简单的聚类算法。但是他十分依赖于用户最初给定的k值。它无法发现任意形状和大小的簇,最适合于发现球状簇。他的时间复杂度为O(tkn)。kmeans算法有两个核心点:计算距离的公式&判断迭代停止的条件。一般距采用欧式距离等可以随意。判断迭代停止的条件可以有:

1) 每个簇的中心点不再变化则停止迭代

2)所有簇的点与这个簇的中心点的误差平方和(SSE)的所有簇的总和不再变化

3)设定人为的迭代次数,观察实验效果。


当初始簇心选择不好的时候聚类的效果会很差。所以后来又有一个人提出了二分k均值(bisectingkmeans),其核心思路是:将初始的一个簇一分为二计算出误差平方和最大的那个簇,对他进行再一次的二分。直至切分的簇的个数为k个停止。 其实质就是不断的对选中的簇做k=2的kmeans切分。

因为聚类的误差平方和能够衡量聚类性能,该值越小表示数据点月接近于它们的质心,聚类效果就越好。所以我们就需要对误差平方和最大的簇进行再一次的划分,因为误差平方和越大,表示该簇聚类越不好,越有可能是多个簇被当成一个簇了,所以我们首先需要对这个簇进行划分。


下面是代码,kmeans的原始代码来源于http://blog.csdn.net/cyxlzzs/article/details/7416491,我稍作了一些修改。


package org.algorithm;

import java.util.ArrayList;
import java.util.List;

/**
 * 二分k均值,实际上是对一个集合做多次的k=2的kmeans划分, 每次划分后会对sse值较大的簇再进行二分。 最终使得或分出来的簇的个数为k个则停止
 * 
 * 这里利用之前别人写好的一个kmeans的java实现作为基础类。
 * 
 * @author l0979365428
 * 
 */
public class BisectingKmeans {

	private int k;// 分成多少簇
	private List<float[]> dataSet;// 当前要被二分的簇
	private List<ClusterSet> cluster; // 簇

	/**
	 * @param args
	 */
	public static void main(String[] args) {

		// 初始化一个Kmean对象,将k置为10
		BisectingKmeans bkm = new BisectingKmeans(5);
		// 初始化试验集
		ArrayList<float[]> dataSet = new ArrayList<float[]>();

		dataSet.add(new float[] { 1, 2 });
		dataSet.add(new float[] { 3, 3 });
		dataSet.add(new float[] { 3, 4 });
		dataSet.add(new float[] { 5, 6 });
		dataSet.add(new float[] { 8, 9 });
		dataSet.add(new float[] { 4, 5 });
		dataSet.add(new float[] { 6, 4 });
		dataSet.add(new float[] { 3, 9 });
		dataSet.add(new float[] { 5, 9 });
		dataSet.add(new float[] { 4, 2 });
		dataSet.add(new float[] { 1, 9 });
		dataSet.add(new float[] { 7, 8 });
		// 设置原始数据集
		bkm.setDataSet(dataSet);
		// 执行算法
		bkm.execute();
		// 得到聚类结果
		// ArrayList<ArrayList<float[]>> cluster = bkm.getCluster();
		// 查看结果
		// for (int i = 0; i < cluster.size(); i++) {
		// bkm.printDataArray(cluster.get(i), "cluster[" + i + "]");
		// }

	}

	public BisectingKmeans(int k) {
		// 比2还小有啥要划分的意义么
		if (k < 2) {
			k = 2;
		}
		this.k = k;

	}

	/**
	 * 设置需分组的原始数据集
	 * 
	 * @param dataSet
	 */

	public void setDataSet(ArrayList<float[]> dataSet) {
		this.dataSet = dataSet;
	}

	/**
	 * 执行算法
	 */
	public void execute() {
		long startTime = System.currentTimeMillis();
		System.out.println("BisectingKmeans begins");
		BisectingKmeans();
		long endTime = System.currentTimeMillis();
		System.out.println("BisectingKmeans running time="
				+ (endTime - startTime) + "ms");
		System.out.println("BisectingKmeans ends");
		System.out.println();
	}

	/**
	 * 初始化
	 */
	private void init() {

		int dataSetLength = dataSet.size();
		if (k > dataSetLength) {
			k = dataSetLength;
		}
	}

	/**
	 * 初始化簇集合
	 * 
	 * @return 一个分为k簇的空数据的簇集合
	 */
	private ArrayList<ArrayList<float[]>> initCluster() {
		ArrayList<ArrayList<float[]>> cluster = new ArrayList<ArrayList<float[]>>();
		for (int i = 0; i < k; i++) {
			cluster.add(new ArrayList<float[]>());
		}

		return cluster;
	}

	/**
	 * Kmeans算法核心过程方法
	 */
	private void BisectingKmeans() {
		init();

		if (k < 2) {
			// 小于2 则原样输出数据集被认为是只分了一个簇
			ClusterSet cs = new ClusterSet();
			cs.setClu(dataSet);
			cluster.add(cs);
		}
		// 调用kmeans进行二分
		cluster = new ArrayList();

		while (cluster.size() < k) {
			List<ClusterSet> clu = kmeans(dataSet);

			for (ClusterSet cl : clu) {

				cluster.add(cl);

			}

			if (cluster.size() == k)
				break;
			else// 顺序计算他们的误差平方和
			{
				
				float maxerro=0f;
				int maxclustersetindex=0;
				int i=0;
				for (ClusterSet tt : cluster) {
					//计算误差平方和并得出误差平方和最大的簇
					float erroe = CommonUtil.countRule(tt.getClu(), tt
							.getCenter());
					tt.setErro(erroe);
					
					if(maxerro<erroe)
					{
						maxerro=erroe;
						maxclustersetindex=i;
					}
					i++;
				}

				dataSet=cluster.get(maxclustersetindex).getClu();
				cluster.remove(maxclustersetindex);
				
			}
		}
		int i=0;
		for(ClusterSet sc:cluster)
		{
		CommonUtil.printDataArray(sc.getClu(),"cluster"+i);
		i++;
		}
		

	}

	/**
	 * 调用kmeans得到两个簇。
	 * 
	 * @param dataSet
	 * @return
	 */
	private List<ClusterSet> kmeans(List<float[]> dataSet) {
		Kmeans k = new Kmeans(2);

		// 设置原始数据集
		k.setDataSet(dataSet);
		// 执行算法
		k.execute();
		// 得到聚类结果
		List<List<float[]>> clus = k.getCluster();

		List<ClusterSet> clusterset = new ArrayList<ClusterSet>();

		int i = 0;
		for (List<float[]> cl : clus) {
			ClusterSet cs = new ClusterSet();
			cs.setClu(cl);
			cs.setCenter(k.getCenter().get(i));
			clusterset.add(cs);
			i++;
		}

		return clusterset;
	}

	class ClusterSet {
		private float erro;
		private List<float[]> clu;
		private float[] center;

		public float getErro() {
			return erro;
		}

		public void setErro(float erro) {
			this.erro = erro;
		}

		public List<float[]> getClu() {
			return clu;
		}

		public void setClu(List<float[]> clu) {
			this.clu = clu;
		}

		public float[] getCenter() {
			return center;
		}

		public void setCenter(float[] center) {
			this.center = center;
		}

	}
}

package org.algorithm;

import java.util.List;

/**
 * 把计算距离和误差的公式抽离出来
 * @author l0979365428
 *
 */
public class CommonUtil {

	/**
	 * 计算两个点之间的距离
	 * 
	 * @param element
	 *            点1
	 * @param center
	 *            点2
	 * @return 距离
	 */
	public static  float distance(float[] element, float[] center) {
		float distance = 0.0f;
		float x = element[0] - center[0];
		float y = element[1] - center[1];
		float z = x * x + y * y;
		distance = (float) Math.sqrt(z);

		return distance;
	}
	/**
	 * 求两点误差平方的方法
	 * 
	 * @param element
	 *            点1
	 * @param center
	 *            点2
	 * @return 误差平方
	 */
	public static  float errorSquare(float[] element, float[] center) {
		float x = element[0] - center[0];
		float y = element[1] - center[1];

		float errSquare = x * x + y * y;

		return errSquare;
	}
	/**
	 * 计算误差平方和准则函数方法
	 */
	public static  float countRule( List<float[]> cluster,float[] center) {
		float jcF = 0;
	
			for (int j = 0; j < cluster.size(); j++) {
				jcF += CommonUtil.errorSquare(cluster.get(j), center);

			}
		
	return  jcF;
	}
	/**
	 * 打印数据,测试用
	 * 
	 * @param dataArray
	 *            数据集
	 * @param dataArrayName
	 *            数据集名称
	 */
	public static  void printDataArray(List<float[]> dataArray, String dataArrayName) {
		for (int i = 0; i < dataArray.size(); i++) {
			System.out.println("print:" + dataArrayName + "[" + i + "]={"
					+ dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");
		}
		System.out.println("===================================");
	}
}

package org.algorithm;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * K均值聚类算法
 */
public class Kmeans {
	private int k;// 分成多少簇
	private int m;// 迭代次数
	private int dataSetLength;// 数据集元素个数,即数据集的长度
	private List<float[]> dataSet;// 数据集链表
	private List<float[]> center;// 中心链表
	private List<List<float[]>> cluster; // 簇
	private List<Float> jc;// 误差平方和,k越接近dataSetLength,误差越小
	private Random random;

	public static void main(String[] args) {
		// 初始化一个Kmean对象,将k置为10
		Kmeans k = new Kmeans(5);
		// 初始化试验集
		ArrayList<float[]> dataSet = new ArrayList<float[]>();

		dataSet.add(new float[] { 1, 2 });
		dataSet.add(new float[] { 3, 3 });
		dataSet.add(new float[] { 3, 4 });
		dataSet.add(new float[] { 5, 6 });
		dataSet.add(new float[] { 8, 9 });
		dataSet.add(new float[] { 4, 5 });
		dataSet.add(new float[] { 6, 4 });
		dataSet.add(new float[] { 3, 9 });
		dataSet.add(new float[] { 5, 9 });
		dataSet.add(new float[] { 4, 2 });
		dataSet.add(new float[] { 1, 9 });
		dataSet.add(new float[] { 7, 8 });
		// 设置原始数据集
		k.setDataSet(dataSet);
		// 执行算法
		k.execute();
		// 得到聚类结果
		List<List<float[]>> cluster = k.getCluster();
		// 查看结果
		for (int i = 0; i < cluster.size(); i++) {
			CommonUtil.printDataArray(cluster.get(i), "cluster[" + i + "]");
		}

	}

	/**
	 * 设置需分组的原始数据集
	 * 
	 * @param dataSet
	 */

	public void setDataSet(List<float[]> dataSet) {
		this.dataSet = dataSet;
	}

	/**
	 * 获取结果分组
	 * 
	 * @return 结果集
	 */

	public List<List<float[]>> getCluster() {
		return cluster;
	}

	/**
	 * 构造函数,传入需要分成的簇数量
	 * 
	 * @param k
	 *            簇数量,若k<=0时,设置为1,若k大于数据源的长度时,置为数据源的长度
	 */
	public Kmeans(int k) {
		if (k <= 0) {
			k = 1;
		}
		this.k = k;
	}

	/**
	 * 初始化
	 */
	private void init() {
		m = 0;
		random = new Random();
		if (dataSet == null || dataSet.size() == 0) {
			initDataSet();
		}
		dataSetLength = dataSet.size();
		if (k > dataSetLength) {
			k = dataSetLength;
		}
		center = initCenters();
		cluster = initCluster();
		jc = new ArrayList<Float>();
	}

	/**
	 * 如果调用者未初始化数据集,则采用内部测试数据集
	 */
	private void initDataSet() {
		dataSet = new ArrayList<float[]>();
		// 其中{6,3}是一样的,所以长度为15的数据集分成14簇和15簇的误差都为0
		float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, { 2, 5 },
				{ 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 },
				{ 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 } };

		for (int i = 0; i < dataSetArray.length; i++) {
			dataSet.add(dataSetArray[i]);
		}
	}

	/**
	 * 初始化中心数据链表,分成多少簇就有多少个中心点
	 * 
	 * @return 中心点集
	 */
	private ArrayList<float[]> initCenters() {
		ArrayList<float[]> center = new ArrayList<float[]>();
		int[] randoms = new int[k];
		boolean flag;
		int temp = random.nextInt(dataSetLength);
		randoms[0] = temp;
		for (int i = 1; i < k; i++) {
			flag = true;
			while (flag) {
				temp = random.nextInt(dataSetLength);
				int j = 0;

				while (j < i) {
					if (temp == randoms[j]) {
						break;
					}
					j++;
				}
				if (j == i) {
					flag = false;
				}
			}
			randoms[i] = temp;
		}

		for (int i = 0; i < k; i++) {
			center.add(dataSet.get(randoms[i]));// 生成初始化中心链表
		}
		return center;
	}

	/**
	 * 初始化簇集合
	 * 
	 * @return 一个分为k簇的空数据的簇集合
	 */
	private List<List<float[]>> initCluster() {
		List<List<float[]>> cluster = new ArrayList();
		for (int i = 0; i < k; i++) {
			cluster.add(new ArrayList<float[]>());
		}

		return cluster;
	}

	/**
	 * 获取距离集合中最小距离的位置
	 * 
	 * @param distance
	 *            距离数组
	 * @return 最小距离在距离数组中的位置
	 */
	private int minDistance(float[] distance) {
		float minDistance = distance[0];
		int minLocation = 0;
		for (int i = 1; i < distance.length; i++) {
			if (distance[i] < minDistance) {
				minDistance = distance[i];
				minLocation = i;
			} else if (distance[i] == minDistance) // 如果相等,随机返回一个位置
			{
				if (random.nextInt(10) < 5) {
					minLocation = i;
				}
			}
		}

		return minLocation;
	}

	/**
	 * 核心,将当前元素放到最小距离中心相关的簇中
	 */
	private void clusterSet() {
		float[] distance = new float[k];
		for (int i = 0; i < dataSetLength; i++) {
			for (int j = 0; j < k; j++) {
				distance[j] = CommonUtil
						.distance(dataSet.get(i), center.get(j));

			}
			int minLocation = minDistance(distance);

			cluster.get(minLocation).add(dataSet.get(i));// 核心,将当前元素放到最小距离中心相关的簇中

		}
	}

	/**
	 * 计算误差平方和准则函数方法
	 */
	private void countRule() {
		float jcF = 0;
		for (int i = 0; i < cluster.size(); i++) {
			for (int j = 0; j < cluster.get(i).size(); j++) {
				jcF += CommonUtil.errorSquare(cluster.get(i).get(j), center
						.get(i));

			}
		}
		jc.add(jcF);
	}

	/**
	 * 设置新的簇中心方法
	 */
	private void setNewCenter() {
		for (int i = 0; i < k; i++) {
			int n = cluster.get(i).size();
			if (n != 0) {
				float[] newCenter = { 0, 0 };
				for (int j = 0; j < n; j++) {
					newCenter[0] += cluster.get(i).get(j)[0];
					newCenter[1] += cluster.get(i).get(j)[1];
				}
				// 设置一个平均值
				newCenter[0] = newCenter[0] / n;
				newCenter[1] = newCenter[1] / n;
				center.set(i, newCenter);
			}
		}
	}

	public List<float[]> getCenter() {
		return center;
	}

	public void setCenter(List<float[]> center) {
		this.center = center;
	}


	/**
	 * Kmeans算法核心过程方法
	 */
	private void kmeans() {
		init();

		// 循环分组,直到误差不变为止
		while (true) {
			clusterSet();
			countRule();

			if (m != 0) {
				if (jc.get(m) - jc.get(m - 1) == 0) {
					break;
				}
			}

			setNewCenter();

			m++;
			cluster.clear();
			cluster = initCluster();
		}

	}

	/**
	 * 执行算法
	 */
	public void execute() {
		long startTime = System.currentTimeMillis();
		System.out.println("kmeans begins");
		kmeans();
		long endTime = System.currentTimeMillis();
		System.out.println("kmeans running time=" + (endTime - startTime)
				+ "ms");
		System.out.println("kmeans ends");
		System.out.println();
	}
}

分别执行两种聚类算法都使得k=5结果如下:

Kmeans:

print:cluster[0]={5.0,6.0}
print:cluster[1]={4.0,5.0}
print:cluster[2]={6.0,4.0}
===================================
print:cluster[0]={1.0,2.0}
print:cluster[1]={3.0,3.0}
print:cluster[2]={3.0,4.0}
print:cluster[3]={4.0,2.0}
===================================
print:cluster[0]={7.0,8.0}
===================================
print:cluster[0]={8.0,9.0}
===================================
print:cluster[0]={3.0,9.0}
print:cluster[1]={5.0,9.0}
print:cluster[2]={1.0,9.0}
===================================

BisectingKmeans:
print:cluster0[0]={8.0,9.0}
print:cluster0[1]={7.0,8.0}
===================================
print:cluster1[0]={3.0,4.0}
print:cluster1[1]={5.0,6.0}
print:cluster1[2]={4.0,5.0}
print:cluster1[3]={6.0,4.0}
===================================
print:cluster2[0]={1.0,2.0}
print:cluster2[1]={3.0,3.0}
print:cluster2[2]={4.0,2.0}
===================================
print:cluster3[0]={1.0,9.0}
===================================
print:cluster4[0]={3.0,9.0}
print:cluster4[1]={5.0,9.0}
===================================

如上有理解问题还请指正。



参考文献:

http://blog.csdn.net/zouxy09/article/details/17590137

http://wenku.baidu.com/link?url=e6sXeX_txPMnNnYy8W28mP-HSD2Lk8cQGbW-4esipqu95r-P4Ke2QPeHLhfBtoie6agplav6VtVwxlyg-jf_5byHJ_Ce93ARqA6U9rn6XKK

《机器学习实战》

版权声明:本文为博主原创文章,未经博主允许不得转载。

相关内容