Skip to content

Commit

Permalink
#204: refactored model metrics and fit/residual obs gathering code
Browse files Browse the repository at this point in the history
  • Loading branch information
dbenn committed Dec 19, 2024
1 parent 5cd7ea1 commit 9ab44fc
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@

import org.aavso.tools.vstar.data.DateInfo;
import org.aavso.tools.vstar.data.Magnitude;
import org.aavso.tools.vstar.data.SeriesType;
import org.aavso.tools.vstar.data.ValidObservation;
import org.aavso.tools.vstar.exception.AlgorithmError;
import org.aavso.tools.vstar.plugin.ModelCreatorPluginBase;
import org.aavso.tools.vstar.ui.mediator.AnalysisType;
import org.aavso.tools.vstar.ui.mediator.Mediator;
import org.aavso.tools.vstar.ui.model.plot.ContinuousModelFunction;
import org.aavso.tools.vstar.ui.model.plot.ICoordSource;
Expand All @@ -41,19 +39,19 @@
* This plug-in creates a piecewise linear model from the current means series.
*
* TODO<br/>
* - add base class function to request for obs of particular series vs ask
* whether to open series dialog.<br/>
* - extrema<br/>
* - fit goodness, e.g. RMS, AIC, BIC and refactoring - add base class function
* to request for obs of particular series vs ask whether to open series
* dialog.<br/>
* - function strings<br/>
* - fit goodness, e.g. RMS, AIC, BIC and refactoring
* - derivative (see VeLaModelCreator)<br/>
* - extrema<br/>
* - change to set mean series rather than retrieved from Mediator, e.g. via
* setParams() for AoV plug-in; same for timeCoordSource (e.g. for AoV) => could
* default to current mode means<br/>
* - change PiecewiseLinearFunction to set currLinearFunc, currLinearFuncDeriv
* as part of t > ... check<br/>
* - Disable AoV model button until selection of result plus phase plot
* mode<br/>
* - derivative<br/>
*/
public class PiecewiseLinearMeanSeriesModel extends ModelCreatorPluginBase {

Expand All @@ -65,7 +63,7 @@ public PiecewiseLinearMeanSeriesModel() {

@Override
public AbstractModel getModel(List<ValidObservation> obs) {
return new PiecewiseLinearModelCreator(obs);
return new PiecewiseLinearModel(obs);
}

@Override
Expand Down Expand Up @@ -165,13 +163,17 @@ public UnivariateRealFunction derivative() {
// TODO: see VeLaModelCreator for an example!
return null;
}

public int numberOfFunctions() {
return functions.size();
}
}

class PiecewiseLinearModelCreator extends AbstractModel {
class PiecewiseLinearModel extends AbstractModel {
private List<ValidObservation> meanObs;
private PiecewiseLinearFunction piecewiseFunction;

PiecewiseLinearModelCreator(List<ValidObservation> obs) {
PiecewiseLinearModel(List<ValidObservation> obs) {
super(obs);

// Get the mean observation list for the current mode
Expand All @@ -197,31 +199,12 @@ public void execute() throws AlgorithmError {
double x = timeCoordSource.getXCoord(i, obs);
double y = piecewiseFunction.value(x);

// TODO: need a base class method to collect fit & residual obs

ValidObservation fitOb = new ValidObservation();
fitOb.setDateInfo(new DateInfo(ob.getJD()));
if (Mediator.getInstance().getAnalysisType() == AnalysisType.PHASE_PLOT) {
fitOb.setPreviousCyclePhase(ob.getPreviousCyclePhase());
fitOb.setStandardPhase(ob.getStandardPhase());
}
fitOb.setMagnitude(new Magnitude(y, 0));
fitOb.setBand(SeriesType.Model);
fitOb.setComments(comment);
fit.add(fitOb);

ValidObservation resOb = new ValidObservation();
resOb.setDateInfo(new DateInfo(ob.getJD()));
if (Mediator.getInstance().getAnalysisType() == AnalysisType.PHASE_PLOT) {
resOb.setPreviousCyclePhase(ob.getPreviousCyclePhase());
resOb.setStandardPhase(ob.getStandardPhase());
}
double residual = ob.getMag() - y;
resOb.setMagnitude(new Magnitude(residual, 0));
resOb.setBand(SeriesType.Residuals);
resOb.setComments(comment);
residuals.add(resOb);
collectObs(y, ob, comment);
}

rootMeanSquare();
informationCriteria(piecewiseFunction.numberOfFunctions());
fitMetricsString();
}

@Override
Expand Down Expand Up @@ -268,9 +251,6 @@ private boolean testLinearFunction() {
double m = -0.5;
result &= Tolerance.areClose(m, function.slope(), DELTA, true);
result &= Tolerance.areClose(10 - (m * 2459645), function.c, DELTA, true);
// result &= function.derivative(2459645) == m;
// result &= function.derivative(2459640) == m;
// result &= function.value(2459642) == m * 2459642 + function.c;
result &= Tolerance.areClose(m * 2459642 + function.c, function.value(2459642), DELTA, true);

return result;
Expand All @@ -288,12 +268,10 @@ private boolean testPiecewiseLinearFunction() {

double t1 = obs.get(0).getJD();
LinearFunction function1 = plf.functions.get(0);
// result &= plf.value(t1) == function1.m * t1 + function1.c;
result &= result &= Tolerance.areClose(function1.m * t1 + function1.c, plf.value(t1), DELTA, true);

double t2 = obs.get(1).getJD();
LinearFunction function2 = plf.functions.get(1);
// result &= plf.value(t2) == function2.m * t2 + function2.c;
result &= result &= Tolerance.areClose(function2.m * t2 + function2.c, plf.value(t2), DELTA, true);

return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,10 @@

import org.aavso.tools.vstar.data.DateInfo;
import org.aavso.tools.vstar.data.Magnitude;
import org.aavso.tools.vstar.data.SeriesType;
import org.aavso.tools.vstar.data.ValidObservation;
import org.aavso.tools.vstar.exception.AlgorithmError;
import org.aavso.tools.vstar.plugin.ModelCreatorPluginBase;
import org.aavso.tools.vstar.ui.dialog.PolynomialDegreeDialog;
import org.aavso.tools.vstar.ui.mediator.AnalysisType;
import org.aavso.tools.vstar.ui.mediator.Mediator;
import org.aavso.tools.vstar.ui.model.plot.ContinuousModelFunction;
import org.aavso.tools.vstar.util.ApacheCommonsDerivativeBasedExtremaFinder;
import org.aavso.tools.vstar.util.Tolerance;
Expand Down Expand Up @@ -72,7 +69,7 @@ public String getDisplayName() {

@Override
public AbstractModel getModel(List<ValidObservation> obs) {
return new PolynomialFitCreator(obs);
return new PolynomialFitModel(obs);
}

/**
Expand Down Expand Up @@ -104,14 +101,12 @@ private int getMaxDegree() {
return 30;
}

class PolynomialFitCreator extends AbstractModel {
class PolynomialFitModel extends AbstractModel {
PolynomialFunction function;
PolynomialFitter fitter;
AbstractLeastSquaresOptimizer optimizer;
double aic = Double.NaN;
double bic = Double.NaN;

PolynomialFitCreator(List<ValidObservation> obs) {
PolynomialFitModel(List<ValidObservation> obs) {
super(obs);

int minDegree = getMinDegree();
Expand Down Expand Up @@ -158,23 +153,6 @@ public boolean hasFuncDesc() {
return true;
}

public String toFitMetricsString() throws AlgorithmError {
String strRepr = functionStrMap.get(LocaleProps.get("MODEL_INFO_FIT_METRICS_TITLE"));

if (strRepr == null) {
// Goodness of fit
strRepr = "RMS: " + NumericPrecisionPrefs.formatOther(optimizer.getRMS());

// Akaike and Bayesean Information Criteria
if (aic != Double.NaN && bic != Double.NaN) {
strRepr += "\nAIC: " + NumericPrecisionPrefs.formatOther(aic);
strRepr += "\nBIC: " + NumericPrecisionPrefs.formatOther(bic);
}
}

return strRepr;
}

@Override
public String toString() {
String strRepr = functionStrMap.get(LocaleProps.get("MODEL_INFO_FUNCTION_TITLE"));
Expand Down Expand Up @@ -286,56 +264,23 @@ public void execute() throws AlgorithmError {

fit = new ArrayList<ValidObservation>();
residuals = new ArrayList<ValidObservation>();
double sumSqResiduals = 0;

String comment = LocaleProps.get("MODEL_INFO_POLYNOMIAL_DEGREE_DESC") + degree;

// Create fit and residual observations and
// compute the sum of squares of residuals for
// Akaike and Bayesean Information Criteria.
for (int i = 0; i < obs.size() && !interrupted; i++) {
ValidObservation ob = obs.get(i);

double x = timeCoordSource.getXCoord(i, obs);
double zeroedX = x - zeroPoint;
double y = function.value(zeroedX);

ValidObservation fitOb = new ValidObservation();
fitOb.setDateInfo(new DateInfo(ob.getJD()));
if (Mediator.getInstance().getAnalysisType() == AnalysisType.PHASE_PLOT) {
fitOb.setPreviousCyclePhase(ob.getPreviousCyclePhase());
fitOb.setStandardPhase(ob.getStandardPhase());
}
fitOb.setMagnitude(new Magnitude(y, 0));
fitOb.setBand(SeriesType.Model);
fitOb.setComments(comment);
fit.add(fitOb);

ValidObservation resOb = new ValidObservation();
resOb.setDateInfo(new DateInfo(ob.getJD()));
if (Mediator.getInstance().getAnalysisType() == AnalysisType.PHASE_PLOT) {
resOb.setPreviousCyclePhase(ob.getPreviousCyclePhase());
resOb.setStandardPhase(ob.getStandardPhase());
}
double residual = ob.getMag() - y;
resOb.setMagnitude(new Magnitude(residual, 0));
resOb.setBand(SeriesType.Residuals);
resOb.setComments(comment);
residuals.add(resOb);

sumSqResiduals += (residual * residual);
}

// Fit metrics (AIC, BIC).
int n = residuals.size();
if (n != 0 && sumSqResiduals / n != 0) {
double commonIC = n * Math.log(sumSqResiduals / n);
aic = commonIC + 2 * degree;
bic = commonIC + degree * Math.log(n);
collectObs(y, ob, comment);
}

functionStrMap.put(LocaleProps.get("MODEL_INFO_FIT_METRICS_TITLE"), toFitMetricsString());

rootMeanSquare();
informationCriteria(degree);
fitMetricsString();

ApacheCommonsDerivativeBasedExtremaFinder finder = new ApacheCommonsDerivativeBasedExtremaFinder(
fit, (DifferentiableUnivariateRealFunction) function, timeCoordSource, zeroPoint);

Expand All @@ -347,7 +292,7 @@ public void execute() throws AlgorithmError {
functionStrMap.put(title, extremaStr);
}

// VeLa, Excel, R equations.
// VeLa, Excel, R model functions.
// TODO: consider Python, e.g. for use with
// matplotlib.
functionStrMap.put(LocaleProps.get("MODEL_INFO_FUNCTION_TITLE"), toString());
Expand All @@ -362,6 +307,11 @@ public void execute() throws AlgorithmError {
}
}

@Override
public void rootMeanSquare() {
rms = optimizer.getRMS();
}

@Override
public Map<String, String> getFunctionStrings() {
return functionStrMap;
Expand Down Expand Up @@ -392,31 +342,26 @@ private boolean testPolynomialFit() {

setDegree(9);

PolynomialFitCreator model = (PolynomialFitCreator) getModel(obs);
AbstractModel model = getModel(obs);

try {
model.execute();

double DELTA = 1e-3;
double DELTA = 1e-6;

List<ValidObservation> fit = model.getFit();
ValidObservation fitOb = fit.get(0);
result &= fitOb.getJD() == 2459301.0;
// System.err.println(result);
result &= Tolerance.areClose(0.629248, fitOb.getMag(), DELTA, true);
// System.err.println(result);

List<ValidObservation> residuals = model.getResiduals();
ValidObservation resOb = residuals.get(0);
result &= resOb.getJD() == 2459301.0;
// System.err.println(result);
result &= Tolerance.areClose(0.000073, resOb.getMag(), DELTA, true);
// System.err.println(result);

result &= Tolerance.areClose(-7923.218889035116, model.aic, DELTA, true);
// System.err.println(result);
result &= Tolerance.areClose(-7888.243952752065, model.bic, DELTA, true);
// System.err.println(result);
result &= Tolerance.areClose(0.0000162266724849, model.getRMS(), DELTA, true);
result &= Tolerance.areClose(-7923.218889035116, model.getAIC(), DELTA, true);
result &= Tolerance.areClose(-7888.243952752065, model.getBIC(), DELTA, true);

} catch (AlgorithmError e) {
result = false;
Expand Down
Loading

0 comments on commit 9ab44fc

Please sign in to comment.