diff options
Diffstat (limited to 'bordeaux/service')
10 files changed, 512 insertions, 133 deletions
diff --git a/bordeaux/service/Android.mk b/bordeaux/service/Android.mk index e8ac4f78f..028308302 100644 --- a/bordeaux/service/Android.mk +++ b/bordeaux/service/Android.mk @@ -32,6 +32,7 @@ LOCAL_SRC_FILES += \ src/android/bordeaux/services/BordeauxClassifier.java \ src/android/bordeaux/services/BordeauxRanker.java \ src/android/bordeaux/services/BordeauxManagerService.java \ + src/android/bordeaux/services/IBordeauxLearner.java \ src/android/bordeaux/services/Learning_StochasticLinearRanker.java \ src/android/bordeaux/services/IBordeauxServiceCallback.aidl \ src/android/bordeaux/services/ILearning_MulticlassPA.aidl \ diff --git a/bordeaux/service/src/android/bordeaux/services/BordeauxManagerService.java b/bordeaux/service/src/android/bordeaux/services/BordeauxManagerService.java index c99e0c50f..e136ca60a 100644 --- a/bordeaux/service/src/android/bordeaux/services/BordeauxManagerService.java +++ b/bordeaux/service/src/android/bordeaux/services/BordeauxManagerService.java @@ -46,8 +46,10 @@ public class BordeauxManagerService { static private ILearning_StochasticLinearRanker mRanker = null; static private ILearning_MulticlassPA mClassifier = null; static private boolean mStarted = false; + public BordeauxManagerService() { } + static private synchronized void bindServices(Context context) { if (mStarted) return; context.bindService(new Intent(IBordeauxService.class.getName()), @@ -55,6 +57,16 @@ public class BordeauxManagerService { mStarted = true; } + + // Call the release, before the Context gets destroyed. + static public synchronized void release(Context context) { + if (mStarted && mConnection != null) { + context.unbindService(mConnection); + mService = null; + mStarted = false; + } + } + static public synchronized IBordeauxService getService(Context context) { if (mService == null) bindServices(context); return mService; @@ -90,6 +102,7 @@ public class BordeauxManagerService { } return mClassifier; } + /** * Class for interacting with the main interface of the service. */ diff --git a/bordeaux/service/src/android/bordeaux/services/BordeauxRanker.java b/bordeaux/service/src/android/bordeaux/services/BordeauxRanker.java index 299054d9c..a0771dc3b 100644 --- a/bordeaux/service/src/android/bordeaux/services/BordeauxRanker.java +++ b/bordeaux/service/src/android/bordeaux/services/BordeauxRanker.java @@ -103,24 +103,12 @@ public class BordeauxRanker { } public void loadModel(String filename) { - if (!retrieveRanker()) - throw new RuntimeException(RANKER_NOTAVAILABLE); - try { - mRanker.LoadModel(filename); - } catch (RemoteException e) { - Log.e(TAG,"Exception: loading model."); - throw new RuntimeException(RANKER_NOTAVAILABLE); - } + // no longer availabe through the interface + return; } public String saveModel(String filename) { - if (!retrieveRanker()) - throw new RuntimeException(RANKER_NOTAVAILABLE); - try { - return mRanker.SaveModel(filename); - } catch (RemoteException e) { - Log.e(TAG,"Exception: saving model."); - throw new RuntimeException(RANKER_NOTAVAILABLE); - } + // no longer availabe through the interface + return null; } } diff --git a/bordeaux/service/src/android/bordeaux/services/BordeauxService.java b/bordeaux/service/src/android/bordeaux/services/BordeauxService.java index b59d1e13b..84a6df070 100644 --- a/bordeaux/service/src/android/bordeaux/services/BordeauxService.java +++ b/bordeaux/service/src/android/bordeaux/services/BordeauxService.java @@ -24,30 +24,25 @@ import android.app.Service; import android.content.ComponentName; import android.content.Context; import android.content.Intent; +import android.content.pm.PackageManager; import android.content.ServiceConnection; import android.os.Bundle; -import android.os.RemoteException; import android.os.Handler; import android.os.IBinder; import android.os.Message; import android.os.Process; import android.os.RemoteCallbackList; +import android.os.RemoteException; import android.view.View; import android.view.View.OnClickListener; import android.widget.Button; import android.widget.TextView; import android.widget.Toast; + import android.bordeaux.R; -import android.bordeaux.learning.MulticlassPA; -import android.bordeaux.learning.StochasticLinearRanker; import android.util.Log; -import java.util.List; -import java.util.ArrayList; -import java.io.*; -import java.util.Scanner; -import java.util.HashMap; -import android.content.pm.PackageManager; +import java.io.*; /** * Machine Learning service that runs in a remote process. @@ -68,16 +63,13 @@ public class BordeauxService extends Service { int mValue = 0; NotificationManager mNotificationManager; - MulticlassPA mMulticlassPA_Learner = null; - - // All saved learning session data - // TODO: backup to the storage - HashMap<String, IBinder> mMulticlassPA_sessions = new HashMap<String, IBinder>(); - HashMap<String, IBinder> mStochasticLinearRanker_sessions = new HashMap<String, IBinder>(); + BordeauxSessionManager mSessionManager; @Override public void onCreate() { + Log.i(TAG, "Bordeaux service created."); mNotificationManager = (NotificationManager)getSystemService(NOTIFICATION_SERVICE); + mSessionManager = new BordeauxSessionManager(this); // Display a notification about us starting. // TODO: don't display the notification after the service is @@ -88,6 +80,9 @@ public class BordeauxService extends Service { @Override public void onDestroy() { + // Save the sessions + mSessionManager.saveSessions(); + // Cancel the persistent notification. mNotificationManager.cancel(R.string.remote_service_started); @@ -96,6 +91,8 @@ public class BordeauxService extends Service { // Unregister all callbacks. mCallbacks.kill(); + + Log.i(TAG, "Bordeaux service stopped."); } @Override @@ -107,39 +104,31 @@ public class BordeauxService extends Service { return null; } + // The main interface implemented by the service. private final IBordeauxService.Stub mBinder = new IBordeauxService.Stub() { - public IBinder getClassifier(String name) { + private IBinder getLearningSession(Class learnerClass, String name) { PackageManager pm = getPackageManager(); String uidname = pm.getNameForUid(getCallingUid()); Log.i(TAG,"Name for uid: " + uidname); - // internal unique key that identifies the learning instance. - // Composed by the unique id of the package plus the user requested - // name. - String key = name + "_MulticlassPA_" + getCallingUid(); - Log.i(TAG, "request classifier session: " + key); - if (mMulticlassPA_sessions.containsKey(key)) { - return mMulticlassPA_sessions.get(key); + BordeauxSessionManager.SessionKey key = + mSessionManager.getSessionKey(uidname, learnerClass, name); + Log.i(TAG, "request learning session: " + key.value); + try { + IBinder iLearner = mSessionManager.getSessionBinder(learnerClass, key); + return iLearner; + } catch (RuntimeException e) { + Log.e(TAG, "Error getting learning interface" + e); + return null; } - IBinder classifier = new Learning_MulticlassPA(); - mMulticlassPA_sessions.put(key, classifier); - Log.i(TAG, "create a new classifier session: " + key); - return classifier; + } + + public IBinder getClassifier(String name) { + return getLearningSession(Learning_MulticlassPA.class, name); } public IBinder getRanker(String name) { - // internal unique key that identifies the learning instance. - // Composed by the unique id of the package plus the user requested - // name. - String key = name + "_Ranker_" + getCallingUid(); - Log.i(TAG, "request ranker session: " + key); - if (mStochasticLinearRanker_sessions.containsKey(key)) { - return mStochasticLinearRanker_sessions.get(key); - } - IBinder ranker = new Learning_StochasticLinearRanker(BordeauxService.this); - mStochasticLinearRanker_sessions.put(key, ranker); - Log.i(TAG, "create a new ranker session: " + key); - return ranker; + return getLearningSession(Learning_StochasticLinearRanker.class, name); } public void registerCallback(IBordeauxServiceCallback cb) { @@ -151,16 +140,6 @@ public class BordeauxService extends Service { } }; - /** - * A MulticlassPA learning interface. - */ - private final ILearning_MulticlassPA.Stub mMulticlassPABinder = new Learning_MulticlassPA(); - /** - * StochasticLinearRanker interface - */ - private final Learning_StochasticLinearRanker mStochasticLinearRankerBinder = new - Learning_StochasticLinearRanker(this); - @Override public void onTaskRemoved(Intent rootIntent) { Toast.makeText(this, "Task removed: " + rootIntent, Toast.LENGTH_LONG).show(); diff --git a/bordeaux/service/src/android/bordeaux/services/BordeauxSessionManager.java b/bordeaux/service/src/android/bordeaux/services/BordeauxSessionManager.java new file mode 100644 index 000000000..89fcac29e --- /dev/null +++ b/bordeaux/service/src/android/bordeaux/services/BordeauxSessionManager.java @@ -0,0 +1,206 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.bordeaux.services; + +import android.bordeaux.services.IBordeauxLearner.ModelChangeCallback; +import android.content.Context; +import android.os.IBinder; +import android.util.Log; + +import java.lang.NoSuchMethodException; +import java.lang.InstantiationException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.HashMap; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +// This class manages the learning sessions from multiple applications. +// The learning sessions are automatically backed up to the storage. +// +class BordeauxSessionManager { + + static private final String TAG = "BordeauxSessionManager"; + private BordeauxSessionStorage mSessionStorage; + + static class Session { + Class learnerClass; + IBordeauxLearner learner; + boolean modified = false; + }; + + static class SessionKey { + String value; + }; + + // Thread to periodically save the sessions to storage + class PeriodicSave extends Thread implements Runnable { + long mSavingInterval = 60000; // 60 seconds + boolean mQuit = false; + PeriodicSave() {} + public void run() { + while (!mQuit) { + try { + sleep(mSavingInterval); + } catch (InterruptedException e) { + // thread waked up. + // ignore + } + saveSessions(); + } + } + } + + PeriodicSave mSavingThread = new PeriodicSave(); + + private ConcurrentHashMap<String, Session> mSessions = + new ConcurrentHashMap<String, Session>(); + + public BordeauxSessionManager(final Context context) { + mSessionStorage = new BordeauxSessionStorage(context); + mSavingThread.start(); + } + + class LearningUpdateCallback implements ModelChangeCallback { + private String mKey; + + public LearningUpdateCallback(String key) { + mKey = key; + } + + public void modelChanged(IBordeauxLearner learner) { + // Save the session + Session session = mSessions.get(mKey); + if (session != null) { + synchronized(session) { + if (session.learner != learner) { + throw new RuntimeException("Session data corrupted!"); + } + session.modified = true; + } + } + } + } + + // internal unique key that identifies the learning instance. + // Composed by the package id of the calling process, learning class name + // and user specified name. + public SessionKey getSessionKey(String callingUid, Class learnerClass, String name) { + SessionKey key = new SessionKey(); + key.value = callingUid + "#" + "_" + name + "_" + learnerClass.getName(); + return key; + } + + public IBinder getSessionBinder(Class learnerClass, SessionKey key) { + if (mSessions.containsKey(key.value)) { + return mSessions.get(key.value).learner.getBinder(); + } + // not in memory cache + try { + // try to find it in the database + Session stored = mSessionStorage.getSession(key.value); + if (stored != null) { + // set the callback, so that we can save the state + stored.learner.setModelChangeCallback(new LearningUpdateCallback(key.value)); + // found session in the storage, put in the cache + mSessions.put(key.value, stored); + return stored.learner.getBinder(); + } + + // if session is not already stored, create a new one. + Log.i(TAG, "create a new learning session: " + key.value); + IBordeauxLearner learner = + (IBordeauxLearner) learnerClass.getConstructor().newInstance(); + // set the callback, so that we can save the state + learner.setModelChangeCallback(new LearningUpdateCallback(key.value)); + Session session = new Session(); + session.learnerClass = learnerClass; + session.learner = learner; + mSessions.put(key.value, session); + return learner.getBinder(); + } catch (Exception e) { + throw new RuntimeException("Can't instantiate class: " + + learnerClass.getName()); + } + } + + public void saveSessions() { + for (Map.Entry<String, Session> session : mSessions.entrySet()) { + synchronized(session) { + // Save the session if it's modified. + if (session.getValue().modified) { + SessionKey skey = new SessionKey(); + skey.value = session.getKey(); + saveSession(skey); + } + } + } + } + + public boolean saveSession(SessionKey key) { + Session session = mSessions.get(key.value); + if (session != null) { + synchronized(session) { + byte[] model = session.learner.getModel(); + + // write to database + boolean res = mSessionStorage.saveSession(key.value, session.learnerClass, model); + if (res) + session.modified = false; + else { + Log.e(TAG, "Can't save session: " + key.value); + } + return res; + } + } + Log.e(TAG, "Session not found: " + key.value); + return false; + } + + // Load all session data into memory. + // The session data will be loaded into the memory from the database, even + // if this method is not called. + public void loadSessions() { + synchronized(mSessions) { + mSessionStorage.getAllSessions(mSessions); + for (Map.Entry<String, Session> session : mSessions.entrySet()) { + // set the callback, so that we can save the state + session.getValue().learner.setModelChangeCallback( + new LearningUpdateCallback(session.getKey())); + } + } + } + + public void removeAllSessionsFromCaller(String callingUid) { + // remove in the hash table + ArrayList<String> remove_keys = new ArrayList<String>(); + for (Map.Entry<String, Session> session : mSessions.entrySet()) { + if (session.getKey().startsWith(callingUid + "#")) { + remove_keys.add(session.getKey()); + } + } + for (String key : remove_keys) { + mSessions.remove(key); + } + // remove all session data from the callingUid in database + // % is used as wild match for the rest of the string in sql + int nDeleted = mSessionStorage.removeSessions(callingUid + "#%"); + if (nDeleted > 0) + Log.i(TAG, "Successfully deleted " + nDeleted + "sessions"); + } +} diff --git a/bordeaux/service/src/android/bordeaux/services/BordeauxSessionStorage.java b/bordeaux/service/src/android/bordeaux/services/BordeauxSessionStorage.java new file mode 100644 index 000000000..89aa370a1 --- /dev/null +++ b/bordeaux/service/src/android/bordeaux/services/BordeauxSessionStorage.java @@ -0,0 +1,146 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.bordeaux.services; + +import android.content.ContentValues; +import android.content.Context; +import android.database.Cursor; +import android.database.SQLException; +import android.database.sqlite.SQLiteDatabase; +import android.database.sqlite.SQLiteOpenHelper; +import android.util.Log; + +import java.lang.System; +import java.util.concurrent.ConcurrentHashMap; + +// This class manages the database for storing the session data. +// +class BordeauxSessionStorage { + + private static final String TAG = "BordeauxSessionStorage"; + // unique key for the session + public static final String COLUMN_KEY = "key"; + // name of the learning class + public static final String COLUMN_CLASS = "class"; + // data of the learning model + public static final String COLUMN_MODEL = "model"; + // last update time + public static final String COLUMN_TIME = "time"; + + private static final String DATABASE_NAME = "bordeaux"; + private static final String SESSION_TABLE = "sessions"; + private static final int DATABASE_VERSION = 1; + private static final String DATABASE_CREATE = + "create table " + SESSION_TABLE + "( " + COLUMN_KEY + + " TEXT primary key, " + COLUMN_CLASS + " TEXT, " + + COLUMN_MODEL + " BLOB, " + COLUMN_TIME + " INTEGER);"; + + private SessionDBHelper mDbHelper; + private SQLiteDatabase mDbSessions; + + BordeauxSessionStorage(final Context context) { + try { + mDbHelper = new SessionDBHelper(context); + mDbSessions = mDbHelper.getWritableDatabase(); + } catch (SQLException e) { + throw new RuntimeException("Can't open session database"); + } + } + + private class SessionDBHelper extends SQLiteOpenHelper { + SessionDBHelper(Context context) { + super(context, DATABASE_NAME, null, DATABASE_VERSION); + } + + @Override + public void onCreate(SQLiteDatabase db) { + db.execSQL(DATABASE_CREATE); + } + + @Override + public void onUpgrade(SQLiteDatabase db, int oldVersion, int newVersion) { + Log.w(TAG, "Upgrading database from version " + oldVersion + " to " + + newVersion + ", which will destroy all old data"); + + db.execSQL("DROP TABLE IF EXISTS " + SESSION_TABLE); + onCreate(db); + } + } + + private ContentValues createSessionEntry(String key, Class learner, byte[] model) { + ContentValues entry = new ContentValues(); + entry.put(COLUMN_KEY, key); + entry.put(COLUMN_TIME, System.currentTimeMillis()); + entry.put(COLUMN_MODEL, model); + entry.put(COLUMN_CLASS, learner.getName()); + return entry; + } + + boolean saveSession(String key, Class learner, byte[] model) { + ContentValues content = createSessionEntry(key, learner, model); + long rowID = + mDbSessions.insertWithOnConflict(SESSION_TABLE, null, content, + SQLiteDatabase.CONFLICT_REPLACE); + return rowID >= 0; + } + + private BordeauxSessionManager.Session getSessionFromCursor(Cursor cursor) { + BordeauxSessionManager.Session session = new BordeauxSessionManager.Session(); + String className = cursor.getString(cursor.getColumnIndex(COLUMN_CLASS)); + try { + session.learnerClass = Class.forName(className); + session.learner = (IBordeauxLearner) session.learnerClass.getConstructor().newInstance(); + } catch (Exception e) { + throw new RuntimeException("Can't instantiate class: " + className); + } + byte[] model = cursor.getBlob(cursor.getColumnIndex(COLUMN_MODEL)); + session.learner.setModel(model); + return session; + } + + BordeauxSessionManager.Session getSession(String key) { + Cursor cursor = mDbSessions.query(true, SESSION_TABLE, + new String[]{COLUMN_KEY, COLUMN_CLASS, COLUMN_MODEL, COLUMN_TIME}, + COLUMN_KEY + "=\"" + key + "\"", null, null, null, null, null); + if (cursor == null) return null; + if (cursor.getCount() == 0) return null; + if (cursor.getCount() > 1) { + throw new RuntimeException("Unexpected duplication in session table for key:" + key); + } + cursor.moveToFirst(); + return getSessionFromCursor(cursor); + } + + void getAllSessions(ConcurrentHashMap<String, BordeauxSessionManager.Session> sessions) { + Cursor cursor = mDbSessions.rawQuery("select * from ?;", new String[]{SESSION_TABLE}); + if (cursor == null) return; + do { + String key = cursor.getString(cursor.getColumnIndex(COLUMN_KEY)); + BordeauxSessionManager.Session session = getSessionFromCursor(cursor); + sessions.put(key, session); + } while (cursor.moveToNext()); + } + + // remove all sessions that have the key that matches the given sql regular + // expression. + int removeSessions(String reKey) { + int nDeleteRows = mDbSessions.delete(SESSION_TABLE, "? like \"?\"", + new String[]{COLUMN_KEY, reKey}); + Log.i(TAG, "Number of rows in session table deleted: " + nDeleteRows); + return nDeleteRows; + } +} diff --git a/bordeaux/service/src/android/bordeaux/services/IBordeauxLearner.java b/bordeaux/service/src/android/bordeaux/services/IBordeauxLearner.java new file mode 100644 index 000000000..114d29410 --- /dev/null +++ b/bordeaux/service/src/android/bordeaux/services/IBordeauxLearner.java @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.bordeaux.services; + +import android.os.IBinder; + +interface IBordeauxLearner { + + interface ModelChangeCallback { + + public void modelChanged(IBordeauxLearner learner); + + } + + public byte [] getModel(); + + public boolean setModel(final byte [] modelData); + + public IBinder getBinder(); + + // call back for the learner model change + public void setModelChangeCallback(ModelChangeCallback callback); +} diff --git a/bordeaux/service/src/android/bordeaux/services/ILearning_StochasticLinearRanker.aidl b/bordeaux/service/src/android/bordeaux/services/ILearning_StochasticLinearRanker.aidl index 912e45644..b0bb5c1b5 100644 --- a/bordeaux/service/src/android/bordeaux/services/ILearning_StochasticLinearRanker.aidl +++ b/bordeaux/service/src/android/bordeaux/services/ILearning_StochasticLinearRanker.aidl @@ -28,7 +28,5 @@ interface ILearning_StochasticLinearRanker { boolean UpdateClassifier(in List<StringFloat> sample_1, in List<StringFloat> sample_2); float ScoreSample(in List<StringFloat> sample); - String SaveModel(in String filename ); - void LoadModel(in String filename); } diff --git a/bordeaux/service/src/android/bordeaux/services/Learning_MulticlassPA.java b/bordeaux/service/src/android/bordeaux/services/Learning_MulticlassPA.java index 8d508aa10..438398d0f 100644 --- a/bordeaux/service/src/android/bordeaux/services/Learning_MulticlassPA.java +++ b/bordeaux/service/src/android/bordeaux/services/Learning_MulticlassPA.java @@ -17,11 +17,15 @@ package android.bordeaux.services; import android.bordeaux.learning.MulticlassPA; +import android.os.IBinder; + import java.util.List; import java.util.ArrayList; -public class Learning_MulticlassPA extends ILearning_MulticlassPA.Stub { +public class Learning_MulticlassPA extends ILearning_MulticlassPA.Stub + implements IBordeauxLearner { private MulticlassPA mMulticlassPA_learner; + private ModelChangeCallback modelChangeCallback = null; class IntFloatArray { int[] indexArray; @@ -44,6 +48,24 @@ public class Learning_MulticlassPA extends ILearning_MulticlassPA.Stub { mMulticlassPA_learner = new MulticlassPA(2, 2, 0.001f); } + // Beginning of the IBordeauxLearner Interface implementation + public byte [] getModel() { + return null; + } + + public boolean setModel(final byte [] modelData) { + return false; + } + + public IBinder getBinder() { + return this; + } + + public void setModelChangeCallback(ModelChangeCallback callback) { + modelChangeCallback = callback; + } + // End of IBordeauxLearner Interface implemenation + // This implementation, combines training and prediction in one step. // The return value is the prediction value for the supplied sample. It // also update the model with the current sample. @@ -52,6 +74,9 @@ public class Learning_MulticlassPA extends ILearning_MulticlassPA.Stub { mMulticlassPA_learner.sparseTrainOneExample(splited.indexArray, splited.floatArray, target); + if (modelChangeCallback != null) { + modelChangeCallback.modelChanged(this); + } } public int Classify(List<IntFloat> sample) { diff --git a/bordeaux/service/src/android/bordeaux/services/Learning_StochasticLinearRanker.java b/bordeaux/service/src/android/bordeaux/services/Learning_StochasticLinearRanker.java index 6b0479794..bb626584a 100644 --- a/bordeaux/service/src/android/bordeaux/services/Learning_StochasticLinearRanker.java +++ b/bordeaux/service/src/android/bordeaux/services/Learning_StochasticLinearRanker.java @@ -16,22 +16,27 @@ package android.bordeaux.services; -import android.content.Context; import android.bordeaux.learning.StochasticLinearRanker; +import android.bordeaux.learning.StochasticLinearRanker.Model; +import android.bordeaux.services.IBordeauxLearner.ModelChangeCallback; +import android.os.IBinder; import android.util.Log; -import java.util.List; -import java.util.ArrayList; + import java.io.*; +import java.lang.ClassNotFoundException; +import java.util.Arrays; +import java.util.ArrayList; +import java.util.List; import java.util.Scanner; -public class Learning_StochasticLinearRanker extends ILearning_StochasticLinearRanker.Stub { +public class Learning_StochasticLinearRanker extends ILearning_StochasticLinearRanker.Stub + implements IBordeauxLearner { String TAG = "ILearning_StochasticLinearRanker"; - Context mContext; private StochasticLinearRanker mLearningSlRanker = null; + private ModelChangeCallback modelChangeCallback = null; - public Learning_StochasticLinearRanker(Context context){ - mContext = context; + public Learning_StochasticLinearRanker(){ } public boolean UpdateClassifier(List<StringFloat> sample_1, List<StringFloat> sample_2){ @@ -51,6 +56,9 @@ public class Learning_StochasticLinearRanker extends ILearning_StochasticLinearR } if (mLearningSlRanker == null) mLearningSlRanker = new StochasticLinearRanker(); boolean res = mLearningSlRanker.updateClassifier(keys_1,values_1,keys_2,values_2); + if (res && modelChangeCallback != null) { + modelChangeCallback.modelChanged(this); + } return res; } @@ -68,67 +76,45 @@ public class Learning_StochasticLinearRanker extends ILearning_StochasticLinearR return res; } - public void LoadModel(String FileName){ - try{ - String str = ""; - StringBuffer buf = new StringBuffer(); - FileInputStream fis = mContext.openFileInput(FileName); - BufferedReader reader = new BufferedReader(new InputStreamReader(fis)); - if (fis!=null) { - while ((str = reader.readLine()) != null) { - buf.append(str + "\n" ); - } - } - fis.close(); - String Temps = buf.toString(); - String[] TempS_Array; - TempS_Array = Temps.split("<>"); - String KeyValueString = TempS_Array[0]; - String ParamString = TempS_Array[1]; - String[] TempS1_Array; - TempS1_Array = KeyValueString.split("\\|"); - int len = TempS1_Array.length; - String[] keys = new String[len]; - float[] values = new float[len]; - for (int i =0; i< len; i++ ){ - String[] TempSd_Array; - TempSd_Array = TempS1_Array[i].split(","); - keys[i] = TempSd_Array[0]; - values[i] = Float.valueOf(TempSd_Array[1].trim()).floatValue(); - } - String[] TempS2_Array; - TempS2_Array = ParamString.split("\\|"); - int lenParam = TempS2_Array.length - 1; - float[] parameters = new float[lenParam]; - for (int i =0; i< lenParam; i++ ){ - parameters[i] = Float.valueOf(TempS2_Array[i].trim()).floatValue(); - } - if (mLearningSlRanker == null) mLearningSlRanker = new StochasticLinearRanker(); - boolean res = mLearningSlRanker.loadModel(keys,values, parameters); - - } catch (IOException e){ + // Beginning of the IBordeauxLearner Interface implementation + public byte [] getModel() { + if (mLearningSlRanker == null) mLearningSlRanker = new StochasticLinearRanker(); + Model model = mLearningSlRanker.getModel(); + try { + ByteArrayOutputStream byteStream = new ByteArrayOutputStream(); + ObjectOutputStream objStream = new ObjectOutputStream(byteStream); + objStream.writeObject(model); + //return byteStream.toByteArray(); + byte[] bytes = byteStream.toByteArray(); + Log.i(TAG, "getModel: " + bytes); + return bytes; + } catch (IOException e) { + throw new RuntimeException("Can't get model"); } } - public String SaveModel(String FileName){ - ArrayList<String> keys_list = new ArrayList<String>(); - ArrayList<Float> values_list = new ArrayList<Float>(); - ArrayList<Float> parameters_list = new ArrayList<Float>(); - if (mLearningSlRanker == null) mLearningSlRanker = new StochasticLinearRanker(); - mLearningSlRanker.getModel(keys_list,values_list, parameters_list); - String S_model = ""; - for (int i = 0; i < keys_list.size(); i++) - S_model = S_model + keys_list.get(i) + "," + values_list.get(i) + "|"; - String S_param =""; - for (int i=0; i< parameters_list.size(); i++) - S_param = S_param + parameters_list.get(i) + "|"; - String Final_Str = S_model + "<> " + S_param; - try{ - FileOutputStream fos = mContext.openFileOutput(FileName, Context.MODE_PRIVATE); - fos.write(Final_Str.getBytes()); - fos.close(); - } catch (IOException e){ + public boolean setModel(final byte [] modelData) { + try { + ByteArrayInputStream input = new ByteArrayInputStream(modelData); + ObjectInputStream objStream = new ObjectInputStream(input); + Model model = (Model) objStream.readObject(); + if (mLearningSlRanker == null) mLearningSlRanker = new StochasticLinearRanker(); + boolean res = mLearningSlRanker.loadModel(model); + Log.i(TAG, "LoadModel: " + modelData); + return res; + } catch (IOException e) { + throw new RuntimeException("Can't load model"); + } catch (ClassNotFoundException e) { + throw new RuntimeException("Learning class not found"); } - return S_model; } + + public IBinder getBinder() { + return this; + } + + public void setModelChangeCallback(ModelChangeCallback callback) { + modelChangeCallback = callback; + } + // End of IBordeauxLearner Interface implemenation } |