diff options
Diffstat (limited to 'bordeaux/learning/stochastic_linear_ranker/java/android/bordeaux/learning/StochasticLinearRanker.java')
-rw-r--r-- | bordeaux/learning/stochastic_linear_ranker/java/android/bordeaux/learning/StochasticLinearRanker.java | 124 |
1 files changed, 88 insertions, 36 deletions
diff --git a/bordeaux/learning/stochastic_linear_ranker/java/android/bordeaux/learning/StochasticLinearRanker.java b/bordeaux/learning/stochastic_linear_ranker/java/android/bordeaux/learning/StochasticLinearRanker.java index 12b7b29ad..35d3ec130 100644 --- a/bordeaux/learning/stochastic_linear_ranker/java/android/bordeaux/learning/StochasticLinearRanker.java +++ b/bordeaux/learning/stochastic_linear_ranker/java/android/bordeaux/learning/StochasticLinearRanker.java @@ -23,6 +23,8 @@ import java.io.Serializable; import java.util.List; import java.util.Arrays; import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; /** * Stochastic Linear Ranker, learns how to rank a sample. The learned rank score @@ -35,19 +37,29 @@ import java.util.ArrayList; */ public class StochasticLinearRanker { String TAG = "StochasticLinearRanker"; - + public static int VAR_NUM = 14; static public class Model implements Serializable { - public ArrayList<String> keys = new ArrayList<String>(); - public ArrayList<Float> values = new ArrayList<Float>(); - public ArrayList<Float> parameters = new ArrayList<Float>(); + public HashMap<String, Float> weights = new HashMap<String, Float>(); + public float weightNormalizer = 1; + public HashMap<String, String> parameters = new HashMap<String, String>(); } - static int VAR_NUM = 15; + /** + * Initializing a ranker + */ public StochasticLinearRanker() { mNativeClassifier = initNativeClassifier(); } /** + * Reset the ranker + */ + public void resetRanker(){ + deleteNativeClassifier(mNativeClassifier); + mNativeClassifier = initNativeClassifier(); + } + + /** * Train the ranker with a pair of samples. A sample, a pair of arrays of * keys and values. The first sample should have higher rank than the second * one. @@ -71,38 +83,71 @@ public class StochasticLinearRanker { /** * Get the current model and parameters of ranker */ - public Model getModel(){ - Model model = new Model(); + public Model getUModel(){ + Model slrModel = new Model(); int len = nativeGetLengthClassifier(mNativeClassifier); - String[] keys = new String[len]; - float[] values = new float[len]; - float[] param = new float[VAR_NUM]; - nativeGetClassifier(keys, values, param, mNativeClassifier); - boolean add_flag; - for (int i=0; i< keys.length ; i++){ - add_flag = model.keys.add(keys[i]); - add_flag = model.values.add(values[i]); - } - for (int i=0; i< param.length ; i++) - add_flag = model.parameters.add(param[i]); - return model; + String[] wKeys = new String[len]; + float[] wValues = new float[len]; + float wNormalizer = 1; + nativeGetWeightClassifier(wKeys, wValues, wNormalizer, mNativeClassifier); + slrModel.weightNormalizer = wNormalizer; + for (int i=0; i< wKeys.length ; i++) + slrModel.weights.put(wKeys[i], wValues[i]); + + String[] paramKeys = new String[VAR_NUM]; + String[] paramValues = new String[VAR_NUM]; + nativeGetParameterClassifier(paramKeys, paramValues, mNativeClassifier); + for (int i=0; i< paramKeys.length ; i++) + slrModel.parameters.put(paramKeys[i], paramValues[i]); + return slrModel; } /** - * use the given model and parameters for ranker + * load the given model and parameters to the ranker */ public boolean loadModel(Model model) { - float[] values = new float[model.values.size()]; - float[] param = new float[model.parameters.size()]; - for (int i = 0; i < model.values.size(); ++i) { - values[i] = model.values.get(i); + String[] wKeys = new String[model.weights.size()]; + float[] wValues = new float[model.weights.size()]; + int i = 0 ; + for (Map.Entry<String, Float> e : model.weights.entrySet()){ + wKeys[i] = e.getKey(); + wValues[i] = e.getValue(); + i++; } - for (int i = 0; i < model.parameters.size(); ++i) { - param[i] = model.parameters.get(i); + boolean res = setModelWeights(wKeys, wValues, model.weightNormalizer); + if (!res) + return false; + + for (Map.Entry<String, String> e : model.parameters.entrySet()){ + res = setModelParameter(e.getKey(), e.getValue()); + if (!res) + return false; } - String[] keys = new String[model.keys.size()]; - model.keys.toArray(keys); - return nativeLoadClassifier(keys, values, param, mNativeClassifier); + return res; + } + + public boolean setModelWeights(String[] keys, float [] values, float normalizer){ + return nativeSetWeightClassifier(keys, values, normalizer, mNativeClassifier); + } + + public boolean setModelParameter(String key, String value){ + boolean res = nativeSetParameterClassifier(key, value, mNativeClassifier); + return res; + } + + /** + * Print a model for debugging + */ + public void print(Model model){ + String Sw = ""; + String Sp = ""; + for (Map.Entry<String, Float> e : model.weights.entrySet()) + Sw = Sw + "<" + e.getKey() + "," + e.getValue() + "> "; + for (Map.Entry<String, String> e : model.parameters.entrySet()) + Sp = Sp + "<" + e.getKey() + "," + e.getValue() + "> "; + Log.i(TAG, "Weights are " + Sw); + Log.i(TAG, "Normalizer is " + model.weightNormalizer); + Log.i(TAG, "Parameters are " + Sp); } @Override @@ -130,12 +175,19 @@ public class StochasticLinearRanker { float[] values_negative, int classifierPtr); - private native float nativeScoreSample(String[] keys, - float[] values, - int classifierPtr); - private native void nativeGetClassifier(String [] keys, float[] values, float[] param, - int classifierPtr); - private native boolean nativeLoadClassifier(String [] keys, float[] values, - float[] param, int classifierPtr); + private native float nativeScoreSample(String[] keys, float[] values, int classifierPtr); + + private native void nativeGetWeightClassifier(String [] keys, float[] values, float normalizer, + int classifierPtr); + + private native void nativeGetParameterClassifier(String [] keys, String[] values, + int classifierPtr); + private native int nativeGetLengthClassifier(int classifierPtr); + + private native boolean nativeSetWeightClassifier(String [] keys, float[] values, + float normalizer, int classifierPtr); + + private native boolean nativeSetParameterClassifier(String key, String value, + int classifierPtr); } |