Commit
Merge pull request #150 from sanity/addLogistic
fixed LogisticRegressionBuilderBug relating to redundant DataTransfor…
athawk81 committed Nov 17, 2015
2 parents fc94e21 + 77fc313 commit 6f6279b
Showing 4 changed files with 65 additions and 13 deletions.
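The gist of the fix, as a minimal sketch condensed from the LogisticRegressionBuilder diff below (not the full class; quickml imports and other members omitted): the builder previously declared both the public logisticRegressionDataTransformer field and a redundant private DataTransformer field, and the constructor and transformData used the private copy, so the two references could be used inconsistently. After this change the private field is removed and a single field is set and used throughout.

public class LogisticRegressionBuilder<D extends LogisticRegressionDTO<D>> {
    // single remaining transformer field; the redundant private DataTransformer is gone
    public StandardDataTransformer<D> logisticRegressionDataTransformer;

    public LogisticRegressionBuilder(StandardDataTransformer<D> dataTransformer) {
        this.logisticRegressionDataTransformer = dataTransformer;
    }

    public D transformData(List<ClassifierInstance> rawInstances) {
        // delegates to the one field the constructor sets
        return logisticRegressionDataTransformer.transformData(rawInstances);
    }
}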
pom.xml: 2 changes (1 addition & 1 deletion)
@@ -32,7 +32,7 @@
be accompanied by a bump in version number, regardless of how minor the change.
0.10.1 -->

<version> 0.10.7</version>
<version> 0.10.9</version>

<repositories>
<repository>
LogisticRegressionBuilder.java
@@ -32,12 +32,11 @@ public class LogisticRegressionBuilder<D extends LogisticRegressionDTO<D>> imple
public StandardDataTransformer<D> logisticRegressionDataTransformer;

private ProductFeatureAppender<ClassifierInstance> productFeatureAppender;
private DataTransformer<ClassifierInstance, SparseClassifierInstance, D> dataTransformer;
GradientDescent<SparseClassifierInstance> gradientDescent = new SparseSGD();
private int minWeightForPavBuckets =2;

public LogisticRegressionBuilder(StandardDataTransformer<D> dataTransformer) {
this.dataTransformer = dataTransformer;
this.logisticRegressionDataTransformer = dataTransformer;
}

public LogisticRegressionBuilder<D> productFeatureAppender(ProductFeatureAppender<ClassifierInstance> productFeatureAppender) {
@@ -67,7 +66,7 @@ public LogisticRegressionBuilder<D> poolAdjacentViolatorsMinWeight(int minWeight

@Override
public D transformData(List<ClassifierInstance> rawInstances){
return dataTransformer.transformData(rawInstances);
return logisticRegressionDataTransformer.transformData(rawInstances);
}


StandardDataTransformer.java
@@ -18,12 +18,6 @@
* Created by alexanderhawk on 10/14/15.
*/
public abstract class StandardDataTransformer<D extends LogisticRegressionDTO<D>> implements DataTransformer<ClassifierInstance, SparseClassifierInstance, D> {
//to do: get label to digit Map and stick in DTO (and transform to logistic regression eventually)
//make LogisticRegressionBuilder use this class and not be tightly coupled to mean normalization (e.g. allow log^2 values)
//make cross validator take a datetransformer (specifically, the Logistic regression PMB, and then do the data normalization
// and set the date time extractor)


/**
* class provides the method: transformInstances, to convert a set of classifier instances into instances that can be processed by
* the LogisticRegressionBuilder.
@@ -32,9 +26,7 @@ public abstract class StandardDataTransformer<D extends LogisticRegressionDTO<D>
product feature appending as well as common co-occurrences should be hyper-params within logistic regression.
*
*/
/*Options, wrap logistic regression? in a new logistic regression class that has a logistic reg transformer?
* Or change sparse classifier instance as the the type of Logistic Regression? I almost prefer this. So now to use it...one just passes in a normal list of training instances
*/


protected ProductFeatureAppender<ClassifierInstance> productFeatureAppender;

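A usage sketch of what the javadoc above describes, drawn from the new test in the next file (package imports omitted; the settings shown are illustrative): a concrete StandardDataTransformer is configured, handed to the builder, and the builder's transformData delegates to it.

// sketch: configure a concrete transformer, give it to the builder, then transform raw instances
List<ClassifierInstance> rawInstances = InstanceLoader.getAdvertisingInstances();
DatedAndMeanNormalizedLogisticRegressionDataTransformer transformer =
        new DatedAndMeanNormalizedLogisticRegressionDataTransformer()
                .minObservationsOfAttribute(35);
LogisticRegressionBuilder<MeanNormalizedAndDatedLogisticRegressionDTO> builder =
        new LogisticRegressionBuilder<>(transformer);
// transformData returns the DTO consumed by the builder's gradient descent step
MeanNormalizedAndDatedLogisticRegressionDTO transformed = builder.transformData(rawInstances);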
@@ -1,5 +1,7 @@
package quickml.supervised.classifier.logRegression;

import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import org.junit.Ignore;
import org.junit.Test;
import org.slf4j.Logger;
@@ -21,9 +23,16 @@
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.WeightedAUCCrossValLossFunction;
import quickml.supervised.dataProcessing.instanceTranformer.CommonCoocurrenceProductFeatureAppender;
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForestBuilder;
import quickml.supervised.predictiveModelOptimizer.FieldValueRecommender;
import quickml.supervised.predictiveModelOptimizer.PredictiveModelOptimizer;
import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.FixedOrderRecommender;
import quickml.supervised.tree.decisionTree.DecisionTreeBuilder;

import java.util.List;
import java.util.Map;

import static quickml.supervised.classifier.logisticRegression.LogisticRegressionBuilder.MIN_OBSERVATIONS_OF_ATTRIBUTE;
import static quickml.supervised.classifier.logisticRegression.SparseSGD.*;

/**
* Created by alexanderhawk on 10/13/15.
@@ -121,4 +130,56 @@ public void testDiabetesInstances() {
logger.info("RF out of time loss: {}", simpleCrossValidator.getLossForModel());
}

@Ignore
@Test
public void optimizerTest(){

List<ClassifierInstance> instances = InstanceLoader.getAdvertisingInstances().subList(0,1000);
CommonCoocurrenceProductFeatureAppender productFeatureAppender = new CommonCoocurrenceProductFeatureAppender<>()
.setMinObservationsOfRawAttribute(35)
.setAllowCategoricalProductFeatures(false)
.setAllowNumericProductFeatures(false)
.setApproximateOverlap(true)
.setMinOverlap(20)
.setIgnoreAttributesCommonToAllInsances(true);

DatedAndMeanNormalizedLogisticRegressionDataTransformer lrdt = new DatedAndMeanNormalizedLogisticRegressionDataTransformer()
.minObservationsOfAttribute(35)
.usingProductFeatures(false)
.productFeatureAppender(productFeatureAppender);

LogisticRegressionBuilder<MeanNormalizedAndDatedLogisticRegressionDTO> logisticRegressionBuilder = new LogisticRegressionBuilder<MeanNormalizedAndDatedLogisticRegressionDTO>(lrdt)
.calibrateWithPoolAdjacentViolators(false)
.gradientDescent(new SparseSGD()
.ridgeRegularizationConstant(0.1)
.learningRate(.0025)
.minibatchSize(1000)
.minEpochs(500)
.maxEpochs(500)
.minPredictedProbablity(1E-3)
.sparseParallelization(true)
);
double start = System.nanoTime();
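// the cross validator scores each candidate model with weighted AUC on out-of-time splits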
EnhancedCrossValidator<LogisticRegression, ClassifierInstance, SparseClassifierInstance, MeanNormalizedAndDatedLogisticRegressionDTO> enhancedCrossValidator = new EnhancedCrossValidator<>(logisticRegressionBuilder,
new ClassifierLossChecker(new WeightedAUCCrossValLossFunction(1.0)),
new OutOfTimeDataFactory(0.25, 48), instances);

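// candidate values to try for each hyper-parameter (alternative values left commented out)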
Map<String, FieldValueRecommender> sgdParams = Maps.newHashMap();
sgdParams.put(RIDGE, new FixedOrderRecommender(.0001));//;, .001, .01, .1, 1));//MonotonicConvergenceRecommender(numTreesList, 0.01));
sgdParams.put(MIN_EPOCHS, new FixedOrderRecommender(8000));// 16000));
sgdParams.put(MAX_EPOCHS, new FixedOrderRecommender(16000));//, 3200));
sgdParams.put(LEARNING_RATE, new FixedOrderRecommender(.0025));//, .001, .005));//11, 14, 16 //Pbest 12
sgdParams.put(MIN_OBSERVATIONS_OF_ATTRIBUTE, new FixedOrderRecommender(20, 50));// 16000));
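// the optimizer runs the cross validator over these candidates; determineOptimalConfig() returns the best-scoring combination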
PredictiveModelOptimizer modelOptimizer = new PredictiveModelOptimizer(sgdParams, enhancedCrossValidator, 2);

logger.info("Optimal sgd parameters: {}", modelOptimizer.determineOptimalConfig());
}

}
