Java实现简单版SVM,java实现svm


Java实现简单版SVM
最近的图像分类工作要用到latent svm,为了更加深入了解svm,自己动手实现一个简单版的。
        之所以说是简单版,因为没有用到拉格朗日,对偶,核函数等等。而是用最简单的梯度下降法求解。其中的数学原理我参考了http://blog.csdn.net/lifeitengup/article/details/10951655,文中是用matlab实现的svm。

源代码和数据集下载:https://github.com/linger2012/simpleSvm
其中数据集来自于libsvm,我找了其中一个数据集http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/breast-cancer_scale。 将她分成两部分,训练集和测试集,对应于train_bc和test_bc。
其中测试结果如下:


package com.linger.svm;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.util.StringTokenizer;

public class SimpleSvm 
{
	private int exampleNum;
	private int exampleDim;
	private double[] w;
	private double lambda;
	private double lr = 0.001;//0.00001
	private double threshold = 0.001;
	private double cost;
	private double[] grad;
	private double[] yp;
	public SimpleSvm(double paramLambda)
	{

		lambda = paramLambda;	
		
	}
	
	private void CostAndGrad(double[][] X,double[] y)
	{
		cost =0;
		for(int m=0;m<exampleNum;m++)
		{
			yp[m]=0;
			for(int d=0;d<exampleDim;d++)
			{
				yp[m]+=X[m][d]*w[d];
			}
			
			if(y[m]*yp[m]-1<0)
			{
				cost += (1-y[m]*yp[m]);
			}
			
		}
		
		for(int d=0;d<exampleDim;d++)
		{
			cost += 0.5*lambda*w[d]*w[d];
		}
		

		for(int d=0;d<exampleDim;d++)
		{
			grad[d] = Math.abs(lambda*w[d]);	
			for(int m=0;m<exampleNum;m++)
			{
				if(y[m]*yp[m]-1<0)
				{
					grad[d]-= y[m]*X[m][d];
				}
			}
		}				
	}
	
	private void update()
	{
		for(int d=0;d<exampleDim;d++)
		{
			w[d] -= lr*grad[d];
		}
	}
	
	public void Train(double[][] X,double[] y,int maxIters)
	{
		exampleNum = X.length;
		if(exampleNum <=0) 
		{
			System.out.println("num of example <=0!");
			return;
		}
		exampleDim = X[0].length;
		w = new double[exampleDim];
		grad = new double[exampleDim];
		yp = new double[exampleNum];
		
		for(int iter=0;iter<maxIters;iter++)
		{
			
			CostAndGrad(X,y);
			System.out.println("cost:"+cost);
			if(cost< threshold)
			{
				break;
			}
			update();
			
		}
	}
	private int predict(double[] x)
	{
		double pre=0;
		for(int j=0;j<x.length;j++)
		{
			pre+=x[j]*w[j];
		}
		if(pre >=0)//这个阈值一般位于-1到1
			return 1;
		else return -1;
	}
	
	public void Test(double[][] testX,double[] testY)
	{
		int error=0;
		for(int i=0;i<testX.length;i++)
		{
			if(predict(testX[i]) != testY[i])
			{
				error++;
			}
		}
		System.out.println("total:"+testX.length);
		System.out.println("error:"+error);
		System.out.println("error rate:"+((double)error/testX.length));
		System.out.println("acc rate:"+((double)(testX.length-error)/testX.length));
	}
	
	
	
	public static void loadData(double[][]X,double[] y,String trainFile) throws IOException
	{
		
		File file = new File(trainFile);
		RandomAccessFile raf = new RandomAccessFile(file,"r");
		StringTokenizer tokenizer,tokenizer2; 

		int index=0;
		while(true)
		{
			String line = raf.readLine();
			
			if(line == null) break;
			tokenizer = new StringTokenizer(line," ");
			y[index] = Double.parseDouble(tokenizer.nextToken());
			//System.out.println(y[index]);
			while(tokenizer.hasMoreTokens())
			{
				tokenizer2 = new StringTokenizer(tokenizer.nextToken(),":");
				int k = Integer.parseInt(tokenizer2.nextToken());
				double v = Double.parseDouble(tokenizer2.nextToken());
				X[index][k] = v;
				//System.out.println(k);
				//System.out.println(v);				
			}	
			X[index][0] =1;
			index++;		
		}
	}
	
	public static void main(String[] args) throws IOException 
	{
		// TODO Auto-generated method stub
		double[] y = new double[400];
		double[][] X = new double[400][11];
		String trainFile = "E:\\project\\workspace\\Algorithms\\bin\\train_bc";
		loadData(X,y,trainFile);
		
		
		SimpleSvm svm = new SimpleSvm(0.0001);
		svm.Train(X,y,7000);
		
		double[] test_y = new double[283];
		double[][] test_X = new double[283][11];
		String testFile = "E:\\project\\workspace\\Algorithms\\bin\\test_bc";
		loadData(test_X,test_y,testFile);
		svm.Test(test_X, test_y);
		
	}

}



本文作者:linger 本文链接:http://blog.csdn.net/lingerlanlan/article/details/38688539


svm格式的文件是干的?

SVM是什么?是一种比较新比较流行的算法,常常用在分类问题和回归问题上。
在我这里,我把它当做一种比较好使的分类器。

A.与NN比较起来,参数设定简单。
在NN里面一般需要设定隐藏层节点数,节点连接方式,训练算法等参数,
麻烦而且没办法预知哪些参数比较有效。
而在SVM里面根据核函数的不同一般只需要选择一个参数或者两个参数就可以
搞定。

B.NN容易陷入局部最优,SVM是求解的全局最优。
NN依初始值不同,有些时候会进入局部最优。SVM只会找到全局最优的唯一解。

C.libsvm是一个常用的SVM实现,有C和JAVA等多个版本。
输入输出的文件格式也非常简单,推荐使用。

PS:
只花了两日就搞定SVM,也出乎自己的意料,
哈哈。
不知道是SVM太简单,还是自己太强呢,
哈哈,多半是后者吧。
参考资料:flybart.spaces.live.com/Blog/cns!1pur-Mm-u26bGHLac3tzjZDA!166.entry
 

有没有用java编写的SVM源代码?

不明白,帮你顶!
 

相关内容