mahout源码分析之Decision Forest 三部曲之二BuildForest(3)Step1Mapper(2)


Mahout版本:0.7,hadoop版本:1.0.4,jdk:1.7.0_25 64bit。

接上篇,分析到OptIgSplitl类的computeSplit函数里面的numbericalSplit函数,看这个函数的输入参数data和attr,应该是针对data计算出一个和attr相关的值而已。往下看

double[] values = sortedValues(data, attr); ,这一句是干啥的?

private static double[] sortedValues(Data data, int attr) {
    double[] values = data.values(attr);
    Arrays.sort(values);

    return values;
  }
sortedValues就是把data中第attr个属性的值全部取出来,然后排个序(attr从0开始)。比如这次debug得到的三个属性s[5,1,4],第5个属性的值全部排序后得到的值如下:

[0.0, 0.02, 0.03, 0.04, 0.05, 0.06, 0.08, 0.09, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.18, 0.19, 0.23, 0.31, 0.32, 0.33, 0.35, 0.37, 0.38, 0.39, 0.44, 0.45, 0.47, 0.48, 0.49, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.57, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66, 0.67, 0.68, 0.69, 0.72, 0.73, 0.76, 0.81, 0.97, 1.1, 1.46, 1.68, 1.76, 2.7, 6.21]
这里一共有59个值,为什么只有在data.values()函数中把相同的值都去除了,采用的HashSet存储的,这个可以在Data类的第193行看到。

然后到了initCounts(data, values);这个就是初始化三个数组的函数,具体代码如下:

void initCounts(Data data, double[] values) {
    counts = new int[values.length][data.getDataset().nblabels()];
    countAll = new int[data.getDataset().nblabels()];
    countLess = new int[data.getDataset().nblabels()];
  }
这里values.length就是59,data.getDataset().nblabels()就是6;

然后到了computeFrequencies(data, attr, values);这个函数主要计算了两个数组:counts[i][j],其中i表示0-58之间的一个数字,j表示0-5之间的一个数字。因为前面的214个数据删除了重复的才变为了59个,所以原始数据里面肯定是有重复的,这里就是计算这些重复的且它的label值要是一样的,比如原始数据如下(只写一列,因为这里只取了一列):

0.3     0

0.3     0

0.3     1

0.4     1

0.5     2

那么values=[0.3,0.4,0.5],counts[0][0]=2,counts[0][1]=1,counts[1][1]=1,counts[2][2]=1,其他counts的值都为0;countAll就是label的值分别相加,countAll[0]=2,countAll[1]=2,countAll[2]=1;

void computeFrequencies(Data data, int attr, double[] values) {
    Dataset dataset = data.getDataset();

    for (int index = 0; index < data.size(); index++) {
      Instance instance = data.get(index);
      counts[ArrayUtils.indexOf(values, instance.get(attr))][(int) dataset.getLabel(instance)]++;
      countAll[(int) dataset.getLabel(instance)]++;
    }
  }
比如上面的glass.data数据在attr是5的情况下得到的counts(size为59)和countAll(size为6)如下:



继续往下:int size = data.size(); double hy = entropy(countAll, size);这里的size就是214,entropy是啥来的?

private static double entropy(int[] counts, int dataSize) {
    if (dataSize == 0) {
      return 0.0;
    }

    double entropy = 0.0;
    double invDataSize = 1.0 / dataSize;

    for (int count : counts) {
      if (count == 0) {
        continue; // otherwise we get a NaN
      }
      double p = count * invDataSize;
      entropy += -p * Math.log(p) / LOG2;
    }

    return entropy;
  }
这个好像叫做熵的?看下它是如何计算的:


下面的公式中pi是每个label的重复值除以总数214的结果。

在继续往下面看:

 double invDataSize = 1.0 / size;

    int best = -1;
    double bestIg = -1.0;

    // try each possible split value
    for (int index = 0; index < values.length; index++) {
      double ig = hy;

      // instance with attribute value < values[index]
      size = DataUtils.sum(countLess);
      ig -= size * invDataSize * entropy(countLess, size);

      // instance with attribute value >= values[index]
      size = DataUtils.sum(countAll);
      ig -= size * invDataSize * entropy(countAll, size);

      if (ig > bestIg) {
        bestIg = ig;
        best = index;
      }

      DataUtils.add(countLess, counts[index]);
      DataUtils.dec(countAll, counts[index]);
    }
上面算到的hy是2.110138986672679。进入for循环,size = DataUtils.sum(countLess);由于第一次countLess全部值为0,所以size也为0,ig=ig-0;然后size = DataUtils.sum(countAll);这个size值为214;然后就是ig -= size * invDataSize * entropy(countAll, size); size*invDataSize不是1么,entropy(countAll,size)不就是hy么?ig前面不是把hy的值赋给它了么?所以ig=ig-hy=ig-ig=0?然后debug得出的答案是:4.440892098500626E-16。尼玛 ,还10的负十六次方。

然后是最后两行,这个是什么意思?运行前:



运行最后两行代码后,变为:


这样就不用我多说了吧,等于是把counts里面的第一条记录加到countLess中,然后再把countAll中相应的次数减去第一条记录。

下面的就是按照这种规律循环遍历最后得到一个 attr --> bestIg 、bestIndex 的对应关系,然后输出 return new Split(attr, bestIg, values[best]);

今晚就到这里吧,不要熬夜。。。

感觉好像更新小说的样子。。。最近看小说太不给力了,卤煮跟新好慢。不过写这个算法更新更慢。哎,看源码。。。


分享,成长,快乐

转载请注明blog地址:http://blog.csdn.net/fansy1990


相关内容

    暂无相关文章