summaryrefslogtreecommitdiff
path: root/bordeaux/learning/stochastic_linear_ranker/java/android/bordeaux/learning/StochasticLinearRanker.java
diff options
context:
space:
mode:
Diffstat (limited to 'bordeaux/learning/stochastic_linear_ranker/java/android/bordeaux/learning/StochasticLinearRanker.java')
-rw-r--r--bordeaux/learning/stochastic_linear_ranker/java/android/bordeaux/learning/StochasticLinearRanker.java124
1 files changed, 88 insertions, 36 deletions
diff --git a/bordeaux/learning/stochastic_linear_ranker/java/android/bordeaux/learning/StochasticLinearRanker.java b/bordeaux/learning/stochastic_linear_ranker/java/android/bordeaux/learning/StochasticLinearRanker.java
index 12b7b29ad..35d3ec130 100644
--- a/bordeaux/learning/stochastic_linear_ranker/java/android/bordeaux/learning/StochasticLinearRanker.java
+++ b/bordeaux/learning/stochastic_linear_ranker/java/android/bordeaux/learning/StochasticLinearRanker.java
@@ -23,6 +23,8 @@ import java.io.Serializable;
import java.util.List;
import java.util.Arrays;
import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
/**
* Stochastic Linear Ranker, learns how to rank a sample. The learned rank score
@@ -35,19 +37,29 @@ import java.util.ArrayList;
*/
public class StochasticLinearRanker {
String TAG = "StochasticLinearRanker";
-
+ public static int VAR_NUM = 14;
static public class Model implements Serializable {
- public ArrayList<String> keys = new ArrayList<String>();
- public ArrayList<Float> values = new ArrayList<Float>();
- public ArrayList<Float> parameters = new ArrayList<Float>();
+ public HashMap<String, Float> weights = new HashMap<String, Float>();
+ public float weightNormalizer = 1;
+ public HashMap<String, String> parameters = new HashMap<String, String>();
}
- static int VAR_NUM = 15;
+ /**
+ * Initializing a ranker
+ */
public StochasticLinearRanker() {
mNativeClassifier = initNativeClassifier();
}
/**
+ * Reset the ranker
+ */
+ public void resetRanker(){
+ deleteNativeClassifier(mNativeClassifier);
+ mNativeClassifier = initNativeClassifier();
+ }
+
+ /**
* Train the ranker with a pair of samples. A sample, a pair of arrays of
* keys and values. The first sample should have higher rank than the second
* one.
@@ -71,38 +83,71 @@ public class StochasticLinearRanker {
/**
* Get the current model and parameters of ranker
*/
- public Model getModel(){
- Model model = new Model();
+ public Model getUModel(){
+ Model slrModel = new Model();
int len = nativeGetLengthClassifier(mNativeClassifier);
- String[] keys = new String[len];
- float[] values = new float[len];
- float[] param = new float[VAR_NUM];
- nativeGetClassifier(keys, values, param, mNativeClassifier);
- boolean add_flag;
- for (int i=0; i< keys.length ; i++){
- add_flag = model.keys.add(keys[i]);
- add_flag = model.values.add(values[i]);
- }
- for (int i=0; i< param.length ; i++)
- add_flag = model.parameters.add(param[i]);
- return model;
+ String[] wKeys = new String[len];
+ float[] wValues = new float[len];
+ float wNormalizer = 1;
+ nativeGetWeightClassifier(wKeys, wValues, wNormalizer, mNativeClassifier);
+ slrModel.weightNormalizer = wNormalizer;
+ for (int i=0; i< wKeys.length ; i++)
+ slrModel.weights.put(wKeys[i], wValues[i]);
+
+ String[] paramKeys = new String[VAR_NUM];
+ String[] paramValues = new String[VAR_NUM];
+ nativeGetParameterClassifier(paramKeys, paramValues, mNativeClassifier);
+ for (int i=0; i< paramKeys.length ; i++)
+ slrModel.parameters.put(paramKeys[i], paramValues[i]);
+ return slrModel;
}
/**
- * use the given model and parameters for ranker
+ * load the given model and parameters to the ranker
*/
public boolean loadModel(Model model) {
- float[] values = new float[model.values.size()];
- float[] param = new float[model.parameters.size()];
- for (int i = 0; i < model.values.size(); ++i) {
- values[i] = model.values.get(i);
+ String[] wKeys = new String[model.weights.size()];
+ float[] wValues = new float[model.weights.size()];
+ int i = 0 ;
+ for (Map.Entry<String, Float> e : model.weights.entrySet()){
+ wKeys[i] = e.getKey();
+ wValues[i] = e.getValue();
+ i++;
}
- for (int i = 0; i < model.parameters.size(); ++i) {
- param[i] = model.parameters.get(i);
+ boolean res = setModelWeights(wKeys, wValues, model.weightNormalizer);
+ if (!res)
+ return false;
+
+ for (Map.Entry<String, String> e : model.parameters.entrySet()){
+ res = setModelParameter(e.getKey(), e.getValue());
+ if (!res)
+ return false;
}
- String[] keys = new String[model.keys.size()];
- model.keys.toArray(keys);
- return nativeLoadClassifier(keys, values, param, mNativeClassifier);
+ return res;
+ }
+
+ public boolean setModelWeights(String[] keys, float [] values, float normalizer){
+ return nativeSetWeightClassifier(keys, values, normalizer, mNativeClassifier);
+ }
+
+ public boolean setModelParameter(String key, String value){
+ boolean res = nativeSetParameterClassifier(key, value, mNativeClassifier);
+ return res;
+ }
+
+ /**
+ * Print a model for debugging
+ */
+ public void print(Model model){
+ String Sw = "";
+ String Sp = "";
+ for (Map.Entry<String, Float> e : model.weights.entrySet())
+ Sw = Sw + "<" + e.getKey() + "," + e.getValue() + "> ";
+ for (Map.Entry<String, String> e : model.parameters.entrySet())
+ Sp = Sp + "<" + e.getKey() + "," + e.getValue() + "> ";
+ Log.i(TAG, "Weights are " + Sw);
+ Log.i(TAG, "Normalizer is " + model.weightNormalizer);
+ Log.i(TAG, "Parameters are " + Sp);
}
@Override
@@ -130,12 +175,19 @@ public class StochasticLinearRanker {
float[] values_negative,
int classifierPtr);
- private native float nativeScoreSample(String[] keys,
- float[] values,
- int classifierPtr);
- private native void nativeGetClassifier(String [] keys, float[] values, float[] param,
- int classifierPtr);
- private native boolean nativeLoadClassifier(String [] keys, float[] values,
- float[] param, int classifierPtr);
+ private native float nativeScoreSample(String[] keys, float[] values, int classifierPtr);
+
+ private native void nativeGetWeightClassifier(String [] keys, float[] values, float normalizer,
+ int classifierPtr);
+
+ private native void nativeGetParameterClassifier(String [] keys, String[] values,
+ int classifierPtr);
+
private native int nativeGetLengthClassifier(int classifierPtr);
+
+ private native boolean nativeSetWeightClassifier(String [] keys, float[] values,
+ float normalizer, int classifierPtr);
+
+ private native boolean nativeSetParameterClassifier(String key, String value,
+ int classifierPtr);
}