OPENNLP-830 Replace the IndexHashTable with java.util.HashMap
git-svn-id: https://svn.apache.org/repos/asf/opennlp/trunk@1745401 13f79535-47bb-0310-9956-ffa450edef68
kottmann committed May 24, 2016
1 parent c904f6f commit 04d186f
Showing 9 changed files with 42 additions and 43 deletions.
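
The change itself is mechanical: every IndexHashTable<String> that mapped predicate labels to their array indices becomes a java.util.Map<String, Integer> filled in label order, and lookups that used to return -1 for an unknown key now return null. A minimal sketch of the pattern (variable names are illustrative, not taken from a specific file):

    // Before: custom table built from the label array.
    // IndexHashTable<String> pmap = new IndexHashTable<String>(predLabels, 0.7d);

    // After: plain HashMap keyed by predicate label, value = index in predLabels.
    Map<String, Integer> pmap = new HashMap<String, Integer>(predLabels.length);
    for (int i = 0; i < predLabels.length; i++) {
      pmap.put(predLabels[i], i);
    }

    Integer idx = pmap.get(somePredicate);   // null instead of -1 when the predicate is unseen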
GISModelWriter.java
@@ -23,12 +23,12 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.AbstractModelWriter;
import opennlp.tools.ml.model.ComparablePredicate;
import opennlp.tools.ml.model.Context;
import opennlp.tools.ml.model.IndexHashTable;

/**
* Abstract parent class for GISModel writers. It provides the persist method
@@ -47,13 +47,13 @@ public GISModelWriter(AbstractModel model) {
Object[] data = model.getDataStructures();

PARAMS = (Context[]) data[0];
IndexHashTable<String> pmap = (IndexHashTable<String>) data[1];
Map<String, Integer> pmap = (Map<String, Integer>) data[1];
OUTCOME_LABELS = (String[]) data[2];
CORRECTION_CONSTANT = (Integer) data[3];
CORRECTION_PARAM = (Double) data[4];

PRED_LABELS = new String[pmap.size()];
pmap.toArray(PRED_LABELS);
pmap.keySet().toArray(PRED_LABELS);
}


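The writers rely on PRED_LABELS lining up with PARAMS by predicate index, and java.util.HashMap gives no guarantee about keySet() iteration order. If the old IndexHashTable.toArray placed each key at its mapped index, an order-preserving rebuild along these lines (a sketch, not code from this commit) would keep that alignment:

    String[] predLabels = new String[pmap.size()];
    for (Map.Entry<String, Integer> entry : pmap.entrySet()) {
      predLabels[entry.getValue()] = entry.getKey();
    }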
QNModel.java
@@ -32,7 +32,11 @@ public int getNumOutcomes() {
return this.outcomeNames.length;
}

private int getPredIndex(String predicate) {
private Integer getPredIndex(String predicate) {

if (predicate == null) throw new RuntimeException("predicate must not be null");
if (pmap == null) throw new RuntimeException("pmap must not be null");

return pmap.get(predicate);
}

@@ -64,9 +68,9 @@ private double[] eval(String[] context, float[] values, double[] probs) {
Context[] params = evalParams.getParams();

for (int ci = 0; ci < context.length; ci++) {
int predIdx = getPredIndex(context[ci]);
Integer predIdx = getPredIndex(context[ci]);

if (predIdx >= 0) {
if (predIdx != null) {
double predValue = 1.0;
if (values != null) predValue = values[ci];

@@ -139,7 +143,8 @@ public boolean equals(Object obj) {
if (this.pmap.size() != objModel.pmap.size())
return false;
String[] pmapArray = new String[pmap.size()];
pmap.toArray(pmapArray);
pmap.keySet().toArray(pmapArray);

for (int i = 0; i < this.pmap.size(); i++) {
if (i != objModel.pmap.get(pmapArray[i]))
return false;
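Because Map.get returns a boxed Integer, getPredIndex now reports an unseen predicate as null rather than the -1 sentinel the old table used, and eval checks for null before unboxing. Roughly (a standalone illustration with a made-up predicate string, assuming java.util.HashMap and java.util.Map are imported):

    Map<String, Integer> pmap = new HashMap<String, Integer>();
    pmap.put("w=the", 0);

    Integer predIdx = pmap.get("w=unseen");   // null: predicate not in the model
    if (predIdx != null) {
      int i = predIdx;   // unboxing is only safe after the null check
    }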
AbstractModel.java
@@ -20,11 +20,13 @@
package opennlp.tools.ml.model;

import java.text.DecimalFormat;
import java.util.HashMap;
import java.util.Map;

public abstract class AbstractModel implements MaxentModel {

/** Mapping between predicates/contexts and an integer representing them. */
protected IndexHashTable<String> pmap;
protected Map<String, Integer> pmap;
/** The names of the outcomes. */
protected String[] outcomeNames;
/** Parameters for the model. */
@@ -37,7 +39,10 @@ public enum ModelType {Maxent,Perceptron,MaxentQn,NaiveBayes};
/** The type of the model. */
protected ModelType modelType;

public AbstractModel(Context[] params, String[] predLabels, IndexHashTable<String> pmap, String[] outcomeNames) {
public AbstractModel(Context[] params, String[] predLabels, Map<String, Integer> pmap, String[] outcomeNames) {

if (pmap == null) throw new RuntimeException("pmap must not be null");

this.pmap = pmap;
this.outcomeNames = outcomeNames;
this.evalParams = new EvalParameters(params,outcomeNames.length);
@@ -54,7 +59,12 @@ public AbstractModel(Context[] params, String[] predLabels, String[] outcomeNames) {
}

private void init(String[] predLabels, String[] outcomeNames){
this.pmap = new IndexHashTable<String>(predLabels, 0.7d);
this.pmap = new HashMap<String, Integer>(predLabels.length);

for (int i = 0; i < predLabels.length; i++) {
pmap.put(predLabels[i], i);
}

this.outcomeNames = outcomeNames;
}

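A small sizing note on the new init: new HashMap<String, Integer>(predLabels.length) sets the initial capacity, and with HashMap's default load factor of 0.75 the map still resizes once it is about three quarters full. If the goal is to avoid rehashing while inserting all predLabels.length entries (the old table was constructed with a 0.7 load factor in mind), the capacity would need a margin, for example (an alternative sketch, not what the commit does):

    Map<String, Integer> pmap =
        new HashMap<String, Integer>((int) (predLabels.length / 0.75f) + 1);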
IndexHashTable.java
@@ -33,7 +33,10 @@
* The table is thread safe and can be accessed concurrently by multiple threads;
* thread safety is achieved through immutability. It is not strictly immutable, however,
* which means the table must still be safely published to other threads.
*
* @deprecated use java.util.HashMap instead
*/
@Deprecated
public class IndexHashTable<T> {

private final Object keys[];
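The deprecated table got its thread safety from being effectively immutable once constructed. A plain HashMap makes no such promise, so a pmap that is shared between threads should be treated as read-only after it is built, for instance by wrapping it (a defensive sketch using java.util.Collections, not part of this commit):

    Map<String, Integer> pmap = new HashMap<String, Integer>();
    for (int i = 0; i < predLabels.length; i++) {
      pmap.put(predLabels[i], i);
    }
    pmap = Collections.unmodifiableMap(pmap);   // safe to publish once fully populated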
NaiveBayesModel.java
@@ -28,7 +28,6 @@
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Context;
import opennlp.tools.ml.model.EvalParameters;
import opennlp.tools.ml.model.IndexHashTable;

/**
* Class implementing the multinomial Naive Bayes classifier model.
@@ -38,19 +37,8 @@ public class NaiveBayesModel extends AbstractModel {
protected double[] outcomeTotals;
protected long vocabulary;

public NaiveBayesModel(Context[] params, String[] predLabels, IndexHashTable<String> pmap, String[] outcomeNames) {
super(params, predLabels, pmap, outcomeNames);
outcomeTotals = initOutcomeTotals(outcomeNames, params);
this.evalParams = new NaiveBayesEvalParameters(params, outcomeNames.length, outcomeTotals, predLabels.length);
modelType = ModelType.NaiveBayes;
}

/**
* @deprecated use the constructor with the {@link IndexHashTable} instead!
*/
@Deprecated
public NaiveBayesModel(Context[] params, String[] predLabels, Map<String, Integer> pmap, String[] outcomeNames) {
super(params, predLabels, outcomeNames);
super(params, predLabels, pmap, outcomeNames);
outcomeTotals = initOutcomeTotals(outcomeNames, params);
this.evalParams = new NaiveBayesEvalParameters(params, outcomeNames.length, outcomeTotals, predLabels.length);
modelType = ModelType.NaiveBayes;
NaiveBayesModelWriter.java
@@ -23,12 +23,12 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.AbstractModelWriter;
import opennlp.tools.ml.model.ComparablePredicate;
import opennlp.tools.ml.model.Context;
import opennlp.tools.ml.model.IndexHashTable;

/**
* Abstract parent class for NaiveBayes writers. It provides the persist method
@@ -46,11 +46,11 @@ public NaiveBayesModelWriter(AbstractModel model) {
Object[] data = model.getDataStructures();
this.numOutcomes = model.getNumOutcomes();
PARAMS = (Context[]) data[0];
IndexHashTable<String> pmap = (IndexHashTable<String>) data[1];
Map<String, Integer> pmap = (Map<String, Integer>) data[1];
OUTCOME_LABELS = (String[]) data[2];

PRED_LABELS = new String[pmap.size()];
pmap.toArray(PRED_LABELS);
pmap.keySet().toArray(PRED_LABELS);
}

protected ComparablePredicate[] sortValues() {
PerceptronModel.java
@@ -28,24 +28,14 @@
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Context;
import opennlp.tools.ml.model.EvalParameters;
import opennlp.tools.ml.model.IndexHashTable;

public class PerceptronModel extends AbstractModel {

public PerceptronModel(Context[] params, String[] predLabels, IndexHashTable<String> pmap, String[] outcomeNames) {
public PerceptronModel(Context[] params, String[] predLabels, Map<String, Integer> pmap, String[] outcomeNames) {
super(params,predLabels,pmap,outcomeNames);
modelType = ModelType.Perceptron;
}

/**
* @deprecated use the constructor with the {@link IndexHashTable} instead!
*/
@Deprecated
public PerceptronModel(Context[] params, String[] predLabels, Map<String,Integer> pmap, String[] outcomeNames) {
super(params,predLabels,outcomeNames);
modelType = ModelType.Perceptron;
}

public PerceptronModel(Context[] params, String[] predLabels, String[] outcomeNames) {
super(params,predLabels,outcomeNames);
modelType = ModelType.Perceptron;
PerceptronModelWriter.java
@@ -23,12 +23,12 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.AbstractModelWriter;
import opennlp.tools.ml.model.ComparablePredicate;
import opennlp.tools.ml.model.Context;
import opennlp.tools.ml.model.IndexHashTable;

/**
* Abstract parent class for Perceptron writers. It provides the persist method
@@ -47,11 +47,11 @@ public PerceptronModelWriter (AbstractModel model) {
Object[] data = model.getDataStructures();
this.numOutcomes = model.getNumOutcomes();
PARAMS = (Context[]) data[0];
IndexHashTable<String> pmap = (IndexHashTable<String>) data[1];
Map<String, Integer> pmap = (Map<String, Integer>) data[1];
OUTCOME_LABELS = (String[])data[2];

PRED_LABELS = new String[pmap.size()];
pmap.toArray(PRED_LABELS);
pmap.keySet().toArray(PRED_LABELS);
}

protected ComparablePredicate[] sortValues () {
SimplePerceptronSequenceTrainer.java
@@ -27,7 +27,6 @@
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.IndexHashTable;
import opennlp.tools.ml.model.MutableContext;
import opennlp.tools.ml.model.OnePassDataIndexer;
import opennlp.tools.ml.model.Sequence;
@@ -68,7 +67,7 @@ public class SimplePerceptronSequenceTrainer extends AbstractEventModelSequenceTrainer {
private MutableContext[] averageParams;

/** Mapping between context and an integer */
private IndexHashTable<String> pmap;
private Map<String, Integer> pmap;

private Map<String,Integer> omap;

@@ -128,8 +127,12 @@ public AbstractModel trainModel(int iterations, SequenceStream sequenceStream, int cutoff, boolean useAverage)

outcomeList = di.getOutcomeList();
predLabels = di.getPredLabels();
pmap = new IndexHashTable<String>(predLabels, 0.7d);
pmap = new HashMap<String, Integer>();

for (int i = 0; i < predLabels.length; i++) {
pmap.put(predLabels[i], i);
}

display("Incorporating indexed data for training... \n");
this.useAverage = useAverage;
numEvents = di.getNumEvents();
