8000 GH-16524 GLM - control variables - Regression, Binomial by maurever · Pull Request #16601 · h2oai/h2o-3 · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

GH-16524 GLM - control variables - Regression, Binomial #16601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions h2o-algos/src/main/java/hex/DataInfo.java
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the changes in this file really necessary? It seems to me that instead of _adaptedFrameNames I could simply use _adaptedFrame.names(). That may look longer and perhaps be slightly slower, but adding a variable that must be manually kept in sync whenever another variable (the Frame) changes seems error-prone to me.

Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,26 @@
public class DataInfo extends Keyed<DataInfo> {
public int [] _activeCols;
public Frame _adaptedFrame; // the modified DataInfo frame (columns sorted by largest categorical -> least then all numerical columns)
public String[] _adaptedFrameNames;
public int _responses; // number of responses

// Refresh the cached snapshot of the adapted frame's column names.
// Must be invoked after every structural change to _adaptedFrame
// (insertVec/remove/add), because _adaptedFrameNames is a copy, not a
// live view. NOTE(review): keeping this cache in sync by hand is
// error-prone — consider calling _adaptedFrame.names() directly instead.
public void updateAdaptedFrameNames(){
_adaptedFrameNames = _adaptedFrame.names();
}

// Install an observation-weights column into the adapted frame.
// If a weights column already exists, its Vec is swapped in place and the
// previous Vec is returned; otherwise the column is inserted at
// weightChunkId(), the name cache is refreshed, and null is returned.
// NOTE(review): the replace path returns early without calling
// updateAdaptedFrameNames() — presumably safe because replace() keeps the
// existing column name, but verify if `name` can differ from the old name.
public Vec setWeights(String name, Vec vec) {
if(_weights)
return _adaptedFrame.replace(weightChunkId(),vec);
_adaptedFrame.insertVec(weightChunkId(),name,vec);
updateAdaptedFrameNames();
// flag set only after the insert succeeds, so chunk-id accounting
// (which reads _weights) stays consistent with the frame layout
_weights = true;
return null;
}

// Remove the observation-weights column from the adapted frame, if one
// is present; a no-op otherwise. Refreshes the cached column names and
// clears the _weights flag so chunk-id accounting stays consistent.
public void dropWeights() {
  if (_weights) {
    _adaptedFrame.remove(weightChunkId());
    updateAdaptedFrameNames();
    _weights = false;
  }
}

Expand All @@ -46,6 +53,7 @@ public void dropInteractions() { // only called to cleanup the InteractionWrappe
for(Vec v:vecs)v.remove();
_interactions = null;
}
updateAdaptedFrameNames();
}

public int[] activeCols() {
Expand All @@ -59,6 +67,7 @@ public int[] activeCols() {
// Append response columns to the adapted frame and bump the response
// count accordingly; the cached column-name snapshot is refreshed so it
// matches the new frame layout.
public void addResponse(String [] names, Vec[] vecs) {
  _adaptedFrame.add(names, vecs);
  updateAdaptedFrameNames();
  _responses += vecs.length;
}

public int[] catNAFill() {return _catNAFill;}
Expand Down Expand Up @@ -141,10 +150,10 @@ public boolean isSigmaScaled(){
public int weightChunkId(){return _cats + _nums;}
public int outputChunkId() { return outputChunkId(0);}
public int outputChunkId(int n) { return n + _cats + _nums + (_weights?1:0) + (_offset?1:0) + (_fold?1:0) + (_treatment?1:0) + _responses;}
public void addOutput(String name, Vec v) {_adaptedFrame.add(name,v);}
public void addOutput(String name, Vec v) {_adaptedFrame.add(name,v); updateAdaptedFrameNames();}
public Vec getOutputVec(int i) {return _adaptedFrame.vec(outputChunkId(i));}
public void setResponse(String name, Vec v){ setResponse(name,v,0);}
public void setResponse(String name, Vec v, int n){ _adaptedFrame.insertVec(responseChunkId(n),name,v);}
public void setResponse(String name, Vec v, int n){ _adaptedFrame.insertVec(responseChunkId(n),name,v); updateAdaptedFrameNames();}

public final boolean _skipMissing;
public final boolean _imputeMissing;
Expand Down Expand Up @@ -338,6 +347,7 @@ public DataInfo(Frame train, Frame valid, int nResponses, boolean useAllFactorLe
tvecs2[i] = train.vec(i);
}
_adaptedFrame = new Frame(names,tvecs2);
updateAdaptedFrameNames();
train.restructure(names,tvecs2);
if (valid != null)
valid.restructure(names,valid.vecs(names));
Expand Down Expand Up @@ -367,6 +377,7 @@ public DataInfo validDinfo(Frame valid) {
valid = Model.makeInteractions(valid, true, _interactions, _useAllFactorLevels, _skipMissing, false).add(valid);
}
res._adaptedFrame = new Frame(_adaptedFrame.names(),valid.vecs(_adaptedFrame.names()));
res.updateAdaptedFrameNames();
res._valid = true;
return res;
}
Expand Down Expand Up @@ -452,6 +463,7 @@ private DataInfo(DataInfo dinfo,Frame fr, double [] normMul, double [] normSub,
_skipMissing = dinfo._skipMissing;
_imputeMissing = dinfo._imputeMissing;
_adaptedFrame = fr;
updateAdaptedFrameNames();
_catOffsets = MemoryManager.malloc4(catLevels.length + 1);
_catMissing = new boolean[catLevels.length];
Arrays.fill(_catMissing,!(dinfo._imputeMissing || dinfo._skipMissing));
Expand Down Expand Up @@ -1465,6 +1477,7 @@ public DataInfo scoringInfo(String[] names, Frame adaptFrame, int nResponses, bo
res._predictor_transform = TransformType.NONE;
res._response_transform = TransformType.NONE;
res._adaptedFrame = adaptFrame;
res.updateAdaptedFrameNames();
res._weights = _weights && adaptFrame.find(names[weightChunkId()]) != -1;
res._offset = _offset && adaptFrame.find(names[offsetChunkId()]) != -1;
res._fold = _fold && adaptFrame.find(names[foldChunkId()]) != -1;
Expand Down Expand Up @@ -1526,4 +1539,6 @@ public double[] imputeInteraction(String name, InteractionWrappedVec iv, double[
}
}



}
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/api/MakeGLMModelHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public GLMModelV3 make_model(int version, MakeGLMModelV3 args){
dinfo.setPredictorTransform(TransformType.NONE);
m._output = new GLMOutput(model.dinfo(), model._output._names, model._output._column_types, model._output._domains,
model._output.coefficientNames(), beta, model._output._binomial, model._output._multinomial,
model._output._ordinal);
model._output._ordinal, model._parms._control_variables);
DKV.put(m._key, m);
GLMModelV3 res = new GLMModelV3();
res.fillFromImpl(m);
Expand Down
56 changes: 55 additions & 1 deletion h2o-algos/src/main/java/hex/glm/GLM.java
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,7 @@ void restoreFromCheckpoint(TwoDimTable sHist, int[] colIndices) {
}

private transient ScoringHistory _scoringHistory;
private transient ScoringHistory _scoringHistoryControlVariableEnabled;
private transient LambdaSearchScoringHistory _lambdaSearchScoringHistory;

long _t0 = System.currentTimeMillis();
Expand Down Expand Up @@ -946,6 +947,8 @@ public void init(boolean expensive) {
_lambdaSearchScoringHistory = new LambdaSearchScoringHistory(_parms._valid != null,_parms._nfolds > 1);
_scoringHistory = new ScoringHistory(_parms._valid != null,_parms._nfolds > 1,
_parms._generate_scoring_history);
_scoringHistoryControlVariableEnabled = new ScoringHistory(_parms._valid != null,_parms._nfolds > 1,
_parms._generate_scoring_history);
_train.bulkRollups(); // make sure we have all the rollups computed in parallel
_t0 = System.currentTimeMillis();
if ((_parms._lambda_search || !_parms._intercept || _parms._lambda == null || _parms._lambda[0] > 0))
Expand Down Expand Up @@ -1408,8 +1411,13 @@ private void restoreScoringHistoryFromCheckpoint() {
int[] colHeadersIndex = grabHeaderIndex(scoringHistory, num2Copy, colHeaders2Restore);
if (_parms._lambda_search)
_lambdaSearchScoringHistory.restoreFromCheckpoint(scoringHistory, colHeadersIndex);
else
else {
_scoringHistory.restoreFromCheckpoint(scoringHistory, colHeadersIndex);
}
if (_model._parms._control_variables != null) {
TwoDimTable scoringHistoryControlVal = _model._output._control_val_scoring_history;
_scoringHistoryControlVariableEnabled.restoreFromCheckpoint(scoringHistoryControlVal, colHeadersIndex);
}
}

static int[] grabHeaderIndex(TwoDimTable sHist, int numHeaders, String[] colHeadersUseful) {
Expand Down Expand Up @@ -3352,6 +3360,40 @@ private void scoreAndUpdateModel() {
Frame train = DKV.<Frame>getGet(_parms._train); // need to keep this frame to get scoring metrics back
_model.score(_parms.train(), null, CFuncRef.from(_parms._custom_metric_func)).delete();
scorePostProcessing(train, t1);
if(_model._parms._control_variables != null){
_model._useControlVariables = true;
long t2 = System.currentTimeMillis();
_model.score(_parms.train(), null, CFuncRef.from(_parms._custom_metric_func)).delete();
scorePostProcessingControlVal(train, t2);
_model._useControlVariables = false;
}
}

// Post-process scoring results produced with control variables enabled:
// records control-variable training (and, when a validation frame exists,
// validation) metrics on the model output, appends an entry to the
// control-variable scoring history, and persists the updated model.
// `train` is the training frame whose metrics were just computed by the
// preceding _model.score(...) call; `t1` is the wall-clock start of that
// scoring pass (used only for the timing log line).
private void scorePostProcessingControlVal(Frame train, long t1) {
  ModelMetrics mtrain = ModelMetrics.getFromDKV(_model, train); // updated by model.scoreAndUpdateModel
  long t2 = System.currentTimeMillis();
  if (mtrain != null) { // idiomatic null check (was !(mtrain == null))
    _model._output._control_val_training_metrics = mtrain;
    _model._output._training_time_ms = t2 - _model._output._start_time; // remember training time
    Log.info(LogMsg(mtrain.toString()));
  } else {
    Log.info(LogMsg("ModelMetrics mtrain is null"));
  }
  Log.info(LogMsg("Control values Training metrics computed in " + (t2 - t1) + "ms"));
  if (_valid != null) {
    Frame valid = DKV.<Frame>getGet(_parms._valid);
    _model.score(_parms.valid(), null, CFuncRef.from(_parms._custom_metric_func)).delete();
    _model._output._control_val_validation_metrics = ModelMetrics.getFromDKV(_model, valid); // updated by model.scoreAndUpdateModel
  }
  // (Removed dead locals: ScoreKeeper trainScore/validScore were filled
  // via fillFrom(...) but never read afterwards.)
  _model.addControlValScoringInfo(_parms, nclasses(), t2, _state._iter); // add to scoringInfo for early stopping
  _model._output._control_val_scoring_history = _scoringHistoryControlVariableEnabled.to2dTable(_parms, null, null);

  _model.update(_job._key);
}

private void scorePostProcessing(Frame train, long t1) {
Expand Down Expand Up @@ -3397,6 +3439,9 @@ private void scorePostProcessing(Frame train, long t1) {
_scoringHistory.addIterationScore(!(mtrain == null), !(_valid == null), _state._iter, _state.likelihood(),
_state.objective(), _state.deviance(), ((GLMMetrics) _model._output._validation_metrics).residual_deviance(),
mtrain._nobs, _model._output._validation_metrics._nobs, _state.lambda(), _state.alpha());
}
if(_parms._control_variables != null) {

}
} else if (!(mtrain == null)) { // only doing training deviance
if (_parms._lambda_search) {
Expand Down Expand Up @@ -3725,6 +3770,9 @@ private void doCompute() {
_model._finalScoring = true; // enables likelihood calculation while scoring
scoreAndUpdateModel();
_model._finalScoring = false; // avoid calculating likelihood in case of further updates
if (_model._parms._control_variables != null) {
_model._useControlVariables = true;
}

if (dfbetas.equals(_parms._influence))
genRID();
Expand All @@ -3736,6 +3784,12 @@ private void doCompute() {
(null != _parms._valid), false, _model._output.getModelCategory(), false, _parms.hasCustomMetricFunc());
_model._output._scoring_history = combineScoringHistory(_model._output._scoring_history,
scoring_history_early_stop);
if(_model._output._control_val_scoring_history != null) {
TwoDimTable control_val_scoring_history_early_stop = ScoringInfo.createScoringHistoryTable(_model.getControlValScoringInfo(),
(null != _parms._valid), false, _model._output.getModelCategory(), false, _parms.hasCustomMetricFunc());
control_val_scoring_history_early_stop.setTableHeader("Scoring history with control variables enabled");
_model._output._control_val_scoring_history = control_val_scoring_history_early_stop;
}
_model._output._varimp = _model._output.calculateVarimp();
_model._output._variable_importances = calcVarImp(_model._output._varimp);
if (_linearConstraintsOn)
Expand Down
Loading
0