8000 Prediction interval related by YibinSun · Pull Request #304 · Waikato/moa · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Prediction interval related #304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package moa.classifiers.predictioninterval;

import com.github.javacliparser.FloatOption;
import com.yahoo.labs.samoa.instances.Instance;
import moa.capabilities.Capabilities;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Regressor;
import moa.core.InstanceExample;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.evaluation.BasicPredictionIntervalEvaluator;
import moa.evaluation.BasicRegressionPerformanceEvaluator;
import moa.learners.Learner;
import moa.options.ClassOption;
import moa.tasks.TaskMonitor;
import org.apache.commons.math3.distribution.NormalDistribution;

/**
 * Prediction-interval learner that wraps a base regressor and emits, for each
 * instance, the triple {lower bound, point prediction, upper bound}.
 *
 * The half-width of the interval is z * RMSE * scalar, where z is the normal
 * quantile for the requested confidence level, RMSE comes from a running
 * regression evaluator, and the scalar adapts based on the empirical coverage
 * observed so far (see {@link #getScalar(double)}).
 */
public class AdaptivePredictionInterval extends AbstractClassifier implements PredictionIntervalLearner {

    /** Base regressor that produces the centre of the interval. */
    public ClassOption learnerOption = new ClassOption("learner", 'l',
            "Learner to train.", Regressor.class, "moa.classifiers.meta.SelfOptimisingKNearestLeaves");

    /** Target coverage probability of the interval (e.g. 0.95). */
    public FloatOption confidenceLevelOption = new FloatOption("confidence", 'c', "confidence level", 0.95, 0, 1);

    /** Lower limit for the adaptive width scalar. */
    public FloatOption scalarLimitOption = new FloatOption("limit", 't', "lower limit of the scalar", 0.1, 0, 1);

    /** Standard normal, created lazily; used for the confidence quantile. */
    private NormalDistribution normalDistribution;

    /** The wrapped base regressor, created lazily from {@link #learnerOption}. */
    private Regressor regressor;

    /** Tracks the regressor's error; its RMSE drives the interval width. */
    private BasicRegressionPerformanceEvaluator evaluator;

    /** Tracks interval coverage; drives the adaptive scalar. */
    private BasicPredictionIntervalEvaluator metaEvaluator;

    @Override
    public void resetLearningImpl() {
        ensureInitialized();
    }

    /**
     * Lazily creates the wrapped regressor and the two evaluators.
     * Centralises the init checks that were previously duplicated in
     * resetLearningImpl and getVotesForInstance, and protects the training
     * path (which bypasses AbstractClassifier's preparation) from NPEs.
     */
    private void ensureInitialized() {
        if (this.regressor == null) {
            this.regressor = (Regressor) getPreparedClassOption(this.learnerOption);
        }
        if (this.evaluator == null) {
            this.evaluator = new BasicRegressionPerformanceEvaluator();
        }
        if (this.metaEvaluator == null) {
            this.metaEvaluator = new BasicPredictionIntervalEvaluator();
        }
    }

    /**
     * Returns the prediction interval for the instance.
     *
     * @param inst the instance to predict for
     * @return {lower bound, point prediction, upper bound}
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        ensureInitialized();
        InstanceExample example = new InstanceExample(inst);

        // Query the base regressor once and reuse the result (the previous
        // version issued the same query twice per prediction).
        double[] votes = ((Learner) this.regressor).getVotesForInstance(example);
        double prediction = votes.length > 0 ? votes[0] : 0;
        double interval = calculateBounds();

        return new double[]{prediction - interval, prediction, prediction + interval};
    }

    /**
     * Feeds one instance to both evaluators and then trains the regressor.
     * Ordering matters: the regression evaluator is updated before the
     * interval is computed, so the interval reflects the newest RMSE
     * (matching the original behaviour).
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        ensureInitialized();
        InstanceExample example = new InstanceExample(inst);

        // Compute the base prediction once; previously the regressor was
        // queried three times per training instance.
        double[] votes = ((Learner) this.regressor).getVotesForInstance(example);
        double prediction = votes.length > 0 ? votes[0] : 0;

        this.evaluator.addResult(example, new double[]{prediction});
        double interval = calculateBounds();
        this.metaEvaluator.addResult(example,
                new double[]{prediction - interval, prediction, prediction + interval});

        ((Learner) this.regressor).trainOnInstance(example);
    }

    /**
     * Trains directly via {@link #trainOnInstanceImpl(Instance)}, bypassing
     * AbstractClassifier's bookkeeping (weight tracking, etc.).
     */
    @Override
    public void trainOnInstance(Instance inst) {
        trainOnInstanceImpl(inst);
    }

    /**
     * Half-width of the interval: scalar * z * RMSE, with z the two-sided
     * normal quantile for the configured confidence level.
     */
    private double calculateBounds() {
        if (this.normalDistribution == null) {
            this.normalDistribution = new NormalDistribution();
        }
        return getScalar(this.metaEvaluator.getCoverage())
                * this.normalDistribution.inverseCumulativeProbability(0.5 + this.confidenceLevelOption.getValue() / 2)
                * this.evaluator.getSquareError();
    }

    /**
     * Computes the adaptive width multiplier from the running empirical
     * coverage (given in percent). During warm-up (< 100 observed examples)
     * the multiplier is 1. Over-coverage shrinks the multiplier towards the
     * configured lower limit; mild under-coverage grows it logarithmically;
     * severe under-coverage falls through to the linear default (100 - coverage).
     * The result is never below {@link #scalarLimitOption}.
     *
     * @param coverage empirical coverage in percent (0..100)
     * @return the clamped width multiplier
     */
    private double getScalar(double coverage) {
        double scalar = 100 - coverage;
        if (this.evaluator.getTotalWeightObserved() >= 100) {
            if (coverage >= this.confidenceLevelOption.getValue() * 100.0) {
                // Over-coverage: interpolate between the limit and 1.
                scalar = (Math.log(-(1 / (100.0 - this.confidenceLevelOption.getValue() * 100)) * (coverage - 100.0)) / Math.log(this.confidenceLevelOption.getValue() * 100) + 1) * (1 - this.scalarLimitOption.getValue()) + this.scalarLimitOption.getValue();
            } else if (coverage > 200 * this.confidenceLevelOption.getValue() - 100.0 && coverage < this.confidenceLevelOption.getValue() * 100.0) {
                // Mild under-coverage: widen the interval logarithmically.
                scalar = -(100.0 - coverage) * Math.log(1 / (100.0 - this.confidenceLevelOption.getValue() * 100.0) * (coverage - (200 * this.confidenceLevelOption.getValue() - 100.0))) / Math.log(this.confidenceLevelOption.getValue() * 100) + 1;
            }
        } else {
            // Warm-up: no scaling until enough examples have been observed.
            scalar = 1.0;
        }
        // Clamp from below so the interval never collapses past the limit.
        return Math.max(scalar, this.scalarLimitOption.getValue());
    }

    @Override
    public void getDescription(StringBuilder sb, int indent) {
        // No textual description provided.
    }

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        // No model-specific measurements reported.
        return new Measurement[0];
    }

    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        // No model description provided.
    }

    @Override
    public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
        // Intentionally empty: members are created lazily via ensureInitialized().
    }

    @Override
    public Capabilities getCapabilities() {
        return super.getCapabilities();
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package moa.classifiers.predictioninterval;

import com.github.javacliparser.FloatOption;
import com.yahoo.labs.samoa.instances.Instance;
import moa.capabilities.Capabilities;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Regressor;
import moa.core.InstanceExample;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.evaluation.BasicRegressionPerformanceEvaluator;
import moa.learners.Learner;
import moa.options.ClassOption;
import moa.tasks.TaskMonitor;
import org.apache.commons.math3.distribution.NormalDistribution;

/**
 * Mean-Variance-Estimation (MVE) style prediction-interval learner: wraps a
 * base regressor and emits, for each instance, the triple
 * {lower bound, point prediction, upper bound}, where the half-width is the
 * normal quantile for the configured confidence level multiplied by the
 * regressor's running RMSE.
 */
public class MVEPredictionInterval extends AbstractClassifier implements PredictionIntervalLearner {

    /** Base regressor that produces the centre of the interval. */
    public ClassOption learnerOption = new ClassOption("learner", 'l',
            "Learner to train.", Regressor.class, "moa.classifiers.meta.SelfOptimisingKNearestLeaves");

    /** Target coverage probability of the interval (e.g. 0.95). */
    public FloatOption confidenceLevelOption = new FloatOption("confidence", 'c', "confidence level", 0.95, 0, 1);

    /** Standard normal, created lazily; used for the confidence quantile. */
    private NormalDistribution normalDistribution;

    /** The wrapped base regressor, created lazily from {@link #learnerOption}. */
    private Regressor regressor;

    /** Tracks the regressor's error; its RMSE drives the interval width. */
    private BasicRegressionPerformanceEvaluator evaluator;

    @Override
    public void resetLearningImpl() {
        ensureInitialized();
    }

    /**
     * Lazily creates the wrapped regressor and the evaluator. Centralises the
     * init checks previously duplicated across methods and protects the
     * training path (which bypasses AbstractClassifier's preparation).
     */
    private void ensureInitialized() {
        if (this.regressor == null) {
            this.regressor = (Regressor) getPreparedClassOption(this.learnerOption);
        }
        if (this.evaluator == null) {
            this.evaluator = new BasicRegressionPerformanceEvaluator();
        }
    }

    /**
     * Returns the prediction interval for the instance.
     *
     * @param inst the instance to predict for
     * @return {lower bound, point prediction, upper bound}
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        ensureInitialized();
        InstanceExample example = new InstanceExample(inst);

        // Query the base regressor once and reuse the result (the previous
        // version issued the same query twice per prediction).
        double[] votes = ((Learner) this.regressor).getVotesForInstance(example);
        double prediction = votes.length > 0 ? votes[0] : 0;
        double interval = calculateBounds();

        return new double[]{prediction - interval, prediction, prediction + interval};
    }

    /**
     * Records the regressor's prediction in the evaluator, then trains the
     * regressor on the instance.
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        ensureInitialized();
        InstanceExample example = new InstanceExample(inst);

        // Compute the base prediction once instead of querying twice.
        double[] votes = ((Learner) this.regressor).getVotesForInstance(example);
        double prediction = votes.length > 0 ? votes[0] : 0;
        this.evaluator.addResult(example, new double[]{prediction});

        ((Learner) this.regressor).trainOnInstance(example);
    }

    /**
     * Trains directly via {@link #trainOnInstanceImpl(Instance)}, bypassing
     * AbstractClassifier's bookkeeping (weight tracking, etc.).
     */
    @Override
    public void trainOnInstance(Instance inst) {
        trainOnInstanceImpl(inst);
    }

    /**
     * Half-width of the interval: z * RMSE, with z the two-sided normal
     * quantile for the configured confidence level.
     */
    private double calculateBounds() {
        if (this.normalDistribution == null) {
            this.normalDistribution = new NormalDistribution();
        }
        return this.normalDistribution.inverseCumulativeProbability(0.5 + this.confidenceLevelOption.getValue() / 2)
                * this.evaluator.getSquareError();
    }

    @Override
    public void getDescription(StringBuilder sb, int indent) {
        // No textual description provided.
    }

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        // No model-specific measurements reported.
        return new Measurement[0];
    }

    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        // No model description provided.
    }

    @Override
    public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
        // Intentionally empty: members are created lazily via ensureInitialized().
    }

    @Override
    public Capabilities getCapabilities() {
        return super.getCapabilities();
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Classifier.java
* Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/
package moa.classifiers.predictioninterval;

import com.yahoo.labs.samoa.instances.Instance;
import moa.core.Example;
import moa.learners.Learner;

/**
* Classifier interface for incremental classification models.
*
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
/**
 * Interface for incremental learners that produce a prediction interval
 * around a point regression prediction.
 *
 * Implementations in this package return the interval from
 * {@link #getVotesForInstance(Instance)} as a three-element array:
 * {lower bound, point prediction, upper bound}.
 */
public interface PredictionIntervalLearner extends Learner<Example<Instance>> {

    /**
     * Trains this learner incrementally on a single instance.
     *
     * @param inst the instance to learn from
     */
    public void trainOnInstance(Instance inst);

    /**
     * Produces the prediction interval for the given instance.
     *
     * @param inst the instance to predict for
     * @return a three-element array {lower bound, point prediction, upper bound}
     */
    public double[] getVotesForInstance(Instance inst);

}
Loading
0