Commit bc9626e
Fixed a bug in constructor
myui committed Mar 9, 2016
1 parent f7a4539

Showing 2 changed files with 31 additions and 28 deletions.
41 changes: 23 additions & 18 deletions core/src/main/java/hivemall/smile/classification/DecisionTree.java
@@ -118,7 +118,7 @@
  * Some techniques such as bagging, boosting, and random forest use more than one decision tree for
  * their analysis.
  */
-public class DecisionTree implements Classifier<double[]> {
+public final class DecisionTree implements Classifier<double[]> {
     /**
      * The attributes of independent variable.
      */
@@ -193,7 +193,7 @@ public static enum SplitRule {
     /**
      * Classification tree node.
      */
-    public static class Node implements Externalizable {
+    public static final class Node implements Externalizable {
 
         /**
          * Predicted class label for this node.
@@ -412,7 +412,7 @@ private static void indent(final StringBuilder builder, final int depth) {
     /**
      * Classification tree node for training purpose.
      */
-    class TrainNode implements Comparable<TrainNode> {
+    private final class TrainNode implements Comparable<TrainNode> {
        /**
         * The associated regression tree node.
         */
@@ -464,19 +464,7 @@ public boolean findBestSplit() {
 
             // Sample count in each class.
             final int[] count = new int[_k];
-            int label = -1;
-            boolean pure = true;
-            for (int i = 0; i < numSamples; i++) {
-                int index = bags[i];
-                int y_i = y[index];
-                count[y_i]++;
-
-                if (label == -1) {
-                    label = y_i;
-                } else if (y_i != label) {
-                    pure = false;
-                }
-            }
+            final boolean pure = sampleCount(count);
 
             // Since all instances have same label, stop splitting.
             if (pure) {
@@ -513,6 +501,23 @@ public boolean findBestSplit() {
             return (node.splitFeature != -1);
         }
 
+        private boolean sampleCount(@Nonnull final int[] count) {
+            int label = -1;
+            boolean pure = true;
+            for (int i = 0; i < bags.length; i++) {
+                int index = bags[i];
+                int y_i = y[index];
+                count[y_i]++;
+
+                if (label == -1) {
+                    label = y_i;
+                } else if (y_i != label) {
+                    pure = false;
+                }
+            }
+            return pure;
+        }
+
         /**
          * Finds the best split cutoff for attribute j at the current node.
          *
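Aside: the new sampleCount helper both fills the per-class histogram and reports whether the bag is pure (every sampled instance carries the same label), which is what lets findBestSplit stop early. A minimal standalone sketch of the same pattern, with hypothetical names outside the Hivemall class:

    // Fill `count` with per-class frequencies over the bagged indices and
    // report whether all sampled labels are identical (a pure node).
    static boolean countAndCheckPure(final int[] count, final int[] bags, final int[] y) {
        int label = -1;
        boolean pure = true;
        for (final int index : bags) {
            final int y_i = y[index];
            count[y_i]++;
            if (label == -1) {
                label = y_i;          // first label seen
            } else if (y_i != label) {
                pure = false;         // a second distinct label: not pure
            }
        }
        return pure;
    }

Note the helper mutates `count` as a side effect while returning the purity flag, so callers get both results from a single pass over the bag.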
@@ -814,8 +819,8 @@ public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @No
         this._minSplit = minSplits;
         this._minLeafSize = minLeafSize;
         this._rule = rule;
-        this._order = (order == null) ? SmileExtUtils.sort(attributes, x) : order;
-        this._importance = new double[attributes.length];
+        this._order = (order == null) ? SmileExtUtils.sort(_attributes, x) : order;
+        this._importance = new double[_attributes.length];
         this._rnd = (rand == null) ? new smile.math.Random() : rand;
 
         final int n = y.length;
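This hunk is the bug named in the commit title: `attributes` is declared @Nullable, and SmileExtUtils.attributeTypes(attributes, x) presumably substitutes a default attribute array when the caller passes null, so later references must go through the normalized field `_attributes`; dereferencing the raw parameter, as the removed lines did, throws a NullPointerException for null input. A minimal sketch of the normalize-then-use-the-field pattern (hypothetical class and helper, not the Hivemall API):

    final class Example {
        private final String[] _attributes;

        Example(/* @Nullable */ final String[] attributes, final int numColumns) {
            this._attributes = normalize(attributes, numColumns);
            // Pre-fix:  new double[attributes.length]   -> NPE when attributes == null
            // Post-fix: new double[_attributes.length]  -> always safe
            final double[] importance = new double[_attributes.length];
            assert importance.length == numColumns;
        }

        private static String[] normalize(final String[] attrs, final int numColumns) {
            if (attrs != null) {
                return attrs;
            }
            final String[] defaults = new String[numColumns];
            java.util.Arrays.fill(defaults, "NUMERIC"); // assume numeric by default
            return defaults;
        }
    }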
18 changes: 8 additions & 10 deletions core/src/main/java/hivemall/smile/regression/RegressionTree.java
@@ -106,7 +106,7 @@
  * @see GradientTreeBoost
  * @see RandomForest
  */
-public class RegressionTree implements Regression<double[]> {
+public final class RegressionTree implements Regression<double[]> {
     /**
      * The attributes of independent variable.
      */
@@ -167,7 +167,7 @@ public interface NodeOutput {
     /**
      * Regression tree node.
      */
-    public static class Node implements Externalizable {
+    public static final class Node implements Externalizable {
 
         /**
          * Predicted real value for this node.
@@ -396,7 +396,7 @@ private static void indent(final StringBuilder builder, final int depth) {
     /**
      * Regression tree node for training purpose.
      */
-    class TrainNode implements Comparable<TrainNode> {
+    private final class TrainNode implements Comparable<TrainNode> {
 
         /**
          * The associated regression tree node.
@@ -480,14 +480,12 @@ public boolean findBestSplit() {
                 variables[i] = i;
             }
 
-            // Loop through features and compute the reduction of squared error,
-            // which is trueCount * trueMean^2 + falseCount * falseMean^2 - count * parentMean^2
             if (_numVars < p) {
                 // Training of Random Forest will get into this race condition.
                 // smile.math.Math uses a static object of random number generator.
                 SmileExtUtils.shuffle(variables, _rnd);
             }
 
+            // Loop through features and compute the reduction of squared error,
+            // which is trueCount * trueMean^2 + falseCount * falseMean^2 - count * parentMean^2
             final int[] samples = _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.length)
                     : null;
             for (int j = 0; j < _numVars; j++) {
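The relocated comment states the split criterion. As a quick check, the formula is the standard decomposition of the sum of squared errors (a derivation, not part of the patch): with parent count $n$ and mean $\bar{y}$, and true/false children $(n_T, \bar{y}_T)$ and $(n_F, \bar{y}_F)$,

    \begin{aligned}
    \mathrm{SSE}_{parent}   &= \textstyle\sum_i y_i^2 - n\,\bar{y}^2 \\
    \mathrm{SSE}_{children} &= \textstyle\sum_i y_i^2 - n_T\,\bar{y}_T^2 - n_F\,\bar{y}_F^2 \\
    \Delta &= \mathrm{SSE}_{parent} - \mathrm{SSE}_{children}
            = n_T\,\bar{y}_T^2 + n_F\,\bar{y}_F^2 - n\,\bar{y}^2
    \end{aligned}

which is exactly trueCount * trueMean^2 + falseCount * falseMean^2 - count * parentMean^2, since the raw sum-of-squares terms cancel.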
@@ -749,7 +747,7 @@ public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
         checkArgument(x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize);
 
         this._attributes = SmileExtUtils.attributeTypes(attributes, x);
-        if (attributes.length != x[0].length) {
+        if (_attributes.length != x[0].length) {
             throw new IllegalArgumentException("-attrs option is invliad: "
                     + Arrays.toString(attributes));
         }
@@ -759,8 +757,8 @@
         this._maxDepth = maxDepth;
         this._minSplit = minSplits;
         this._minLeafSize = minLeafSize;
-        this._order = (order == null) ? SmileExtUtils.sort(attributes, x) : order;
-        this._importance = new double[attributes.length];
+        this._order = (order == null) ? SmileExtUtils.sort(_attributes, x) : order;
+        this._importance = new double[_attributes.length];
         this._rnd = (rand == null) ? new smile.math.Random() : rand;
         this._nodeOutput = output;
 
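The race-condition comment in the findBestSplit hunk above explains the `_rnd` default seen here: smile.math.Math shuffles through a static, shared random number generator, so concurrent tree training contends on that state, while an injectable per-tree smile.math.Random sidesteps it. A minimal sketch of the per-instance pattern, using java.util.Random as a stand-in:

    import java.util.Random;

    // Each trainer owns its generator, so parallel training never touches
    // shared static RNG state.
    final class TreeTrainer {
        private final Random rnd;

        TreeTrainer(final Random rnd) {
            this.rnd = (rnd == null) ? new Random() : rnd; // mirror the diff's null default
        }

        // Fisher-Yates shuffle driven by the instance-local generator.
        void shuffle(final int[] variables) {
            for (int i = variables.length - 1; i > 0; i--) {
                final int j = rnd.nextInt(i + 1);
                final int tmp = variables[i];
                variables[i] = variables[j];
                variables[j] = tmp;
            }
        }
    }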
