summaryrefslogtreecommitdiff
path: root/bordeaux/learning/predictor_histogram/java/android/bordeaux/learning/HistogramPredictor.java
blob: e63f40dd9322f6b5232eb451e7c7d4c3162a036c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
/*
 * Copyright (C) 2011 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.bordeaux.learning;

import android.util.Log;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
/**
 * A histogram-based predictor that records co-occurrences of applications with specific
 * features, for example location, time of day, etc. For each feature, the histogram is kept
 * in a two-level hash table: the first-level key is the feature value and the second-level
 * key is the app id.
 */
// TODOs:
// 1. Use a forgetting factor to down-weight instances in proportion to their age.
// 2. Different features could have different weights on prediction scores.
// 3. Add a function to remove a sample id (i.e. remove apps that are uninstalled).


public class HistogramPredictor {
    final static String TAG = "HistogramPredictor";

    // Feature name -> histogram of (feature value, app) co-occurrence counts.
    private final HashMap<String, HistogramCounter> mPredictor =
            new HashMap<String, HistogramCounter>();

    // App (class) name -> number of samples recorded for that app.
    private final HashMap<String, Integer> mClassCounts = new HashMap<String, Integer>();
    // Apps that findTopClasses() never returns.
    private final HashSet<String> mBlacklist = new HashSet<String>();

    // A feature value whose qualifying app counts sum to less than this yields no scores.
    private static final int MINIMAL_FEATURE_VALUE_COUNTS = 5;
    // Apps with fewer recorded samples than this are never returned as predictions.
    private static final int MINIMAL_APP_APPEARANCE_COUNTS = 5;

    // This parameter ranges from 0 to 1 and determines the effect of the app prior.
    // When it is set to 0, the app prior is completely neglected. When it is set to 1
    // the predictor is a standard naive Bayes model. It is a double so that fractional
    // values in that range can actually be expressed.
    private static final double PRIOR_K_VALUE = 1.0;

    // Feature whose histogram is used to rebuild mClassCounts when a model is loaded.
    private static final String TIME_OF_WEEK_FEATURE = "Time of Week";

    // NOTE(review): not referenced anywhere in this file; the active blacklist is the
    // one passed to the constructor.
    private static final String[] APP_BLACKLIST = {
        "com.android.contacts",
        "com.android.chrome",
        "com.android.providers.downloads.ui",
        "com.android.settings",
        "com.android.vending",
        "com.android.mms",
        "com.google.android.gm",
        "com.google.android.gallery3d",
        "com.google.android.apps.googlevoice",
    };

    /**
     * Creates an empty predictor.
     *
     * @param blackList app names that findTopClasses() must never return; a null array is
     *                  treated as an empty blacklist
     */
    public HistogramPredictor(String[] blackList) {
        if (blackList != null) {
            Collections.addAll(mBlacklist, blackList);
        }
    }

    /*
     * Keeps the histogram counts for a single feature, as
     * feature value -> (class name -> count), and derives per-class scores from them.
     * Static because it uses no state of the enclosing predictor.
     */
    private static class HistogramCounter {
        private final HashMap<String, HashMap<String, Integer> > mCounter =
                new HashMap<String, HashMap<String, Integer> >();

        /** Replaces all counts with a shallow copy of {@code counter}. */
        public void setCounter(HashMap<String, HashMap<String, Integer> > counter) {
            resetCounter();
            mCounter.putAll(counter);
        }

        /** Removes all counts. */
        public void resetCounter() {
            mCounter.clear();
        }

        /** Increments the co-occurrence count of {@code className} with {@code featureValue}. */
        public void addSample(String className, String featureValue) {
            HashMap<String, Integer> classCounts = mCounter.get(featureValue);
            if (classCounts == null) {
                classCounts = new HashMap<String, Integer>();
                mCounter.put(featureValue, classCounts);
            }
            Integer count = classCounts.get(className);
            classCounts.put(className, (count == null) ? 1 : count + 1);
        }

        /**
         * Returns log(count) for every class seen more than once with {@code featureValue}.
         * The map is empty if the value is unknown, or if those counts sum to less than
         * MINIMAL_FEATURE_VALUE_COUNTS.
         */
        public HashMap<String, Double> getClassScores(String featureValue) {
            HashMap<String, Double> classScores = new HashMap<String, Double>();
            HashMap<String, Integer> classCounts = mCounter.get(featureValue);
            if (classCounts == null) {
                return classScores;
            }

            int totalCount = 0;
            for (Map.Entry<String, Integer> entry : classCounts.entrySet()) {
                int count = entry.getValue();
                // Apps with a count of at most one are treated as having count one, so their
                // score log(count) is zero. They are omitted from classScores and do not
                // contribute to totalCount.
                if (count > 1) {
                    classScores.put(entry.getKey(), Math.log((double) count));
                    totalCount += count;
                }
            }
            // Too little evidence for this feature value: contribute no scores at all.
            if (totalCount < MINIMAL_FEATURE_VALUE_COUNTS) {
                classScores.clear();
            }
            return classScores;
        }

        /**
         * Serializes the counts to a byte array.
         *
         * @throws RuntimeException if serialization fails
         */
        public byte[] getModel() {
            synchronized (mCounter) {
                return serialize(mCounter);
            }
        }

        /**
         * Replaces the counts with those deserialized from {@code modelData}. The data is
         * fully decoded before the current counts are touched, so a failed load leaves them
         * unchanged.
         *
         * @return always true
         * @throws RuntimeException if the data cannot be deserialized
         */
        public boolean setModel(final byte[] modelData) {
            HashMap<String, HashMap<String, Integer> > model = deserialize(modelData);
            synchronized (mCounter) {
                mCounter.clear();
                mCounter.putAll(model);
            }
            return true;
        }

        /** Returns the live (not copied) count table. */
        public HashMap<String, HashMap<String, Integer> > getCounter() {
            return mCounter;
        }

        @Override
        public String toString() {
            StringBuilder result = new StringBuilder();
            for (Map.Entry<String, HashMap<String, Integer> > entry : mCounter.entrySet()) {
                result.append("{ ").append(entry.getKey()).append(" : ")
                        .append(entry.getValue()).append(" }");
            }
            return result.toString();
        }
    }

    /*
     * Given a map of feature name -> value pairs, returns the topK most likely apps to be
     * launched together with their scores, sorted by descending score. If topK is zero, the
     * whole candidate list is returned.
     *
     * An app's score is the sum of log count(feature value, app) over the features in which
     * it scored, minus (n - PRIOR_K_VALUE) * log count(app). Here n is the number of features
     * that produced any scores. With PRIOR_K_VALUE = 1 this equals the naive Bayes log
     * posterior up to a constant. Blacklisted apps and apps with fewer than
     * MINIMAL_APP_APPEARANCE_COUNTS samples are excluded.
     *
     * Throws IllegalArgumentException if topK is negative. Throws RuntimeException if a
     * scored app has no recorded class count.
     */
    public List<Map.Entry<String, Double> > findTopClasses(Map<String, String> features, int topK) {
        if (topK < 0) {
            throw new IllegalArgumentException("topK must be >= 0: " + topK);
        }

        // Sum the per-feature log-count scores of every app.
        HashMap<String, Double> appScores = new HashMap<String, Double>();
        int validFeatureCount = 0;
        for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) {
            String featureName = entry.getKey();
            if (!features.containsKey(featureName)) {
                continue;
            }
            HashMap<String, Double> scoreMap =
                    entry.getValue().getClassScores(features.get(featureName));
            if (scoreMap.isEmpty()) {
                continue;
            }
            validFeatureCount++;

            for (Map.Entry<String, Double> item : scoreMap.entrySet()) {
                String appName = item.getKey();
                double appScore = item.getValue();
                Double previous = appScores.get(appName);
                appScores.put(appName, (previous == null) ? appScore : appScore + previous);
            }
        }

        // Apply the app prior, dropping blacklisted and rarely seen apps.
        HashMap<String, Double> appCandidates = new HashMap<String, Double>();
        for (Map.Entry<String, Double> entry : appScores.entrySet()) {
            String appName = entry.getKey();
            if (mBlacklist.contains(appName)) {
                Log.i(TAG, appName + " is in blacklist");
                continue;
            }
            Integer appCount = mClassCounts.get(appName);
            if (appCount == null) {
                // Every scored app should have been counted by addSample()/setModel().
                throw new RuntimeException("class count error!");
            }
            if (appCount < MINIMAL_APP_APPEARANCE_COUNTS) {
                Log.i(TAG, appName + " doesn't have enough counts");
                continue;
            }

            double appPrior = Math.log((double) appCount);
            appCandidates.put(appName,
                    entry.getValue() - appPrior * (validFeatureCount - PRIOR_K_VALUE));
        }

        // Sort by descending score.
        List<Map.Entry<String, Double> > appList =
                new ArrayList<Map.Entry<String, Double> >(appCandidates.entrySet());
        Collections.sort(appList, new Comparator<Map.Entry<String, Double> >() {
            public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
                return o2.getValue().compareTo(o1.getValue());
            }
        });

        int limit = (topK == 0) ? appList.size() : Math.min(topK, appList.size());
        return appList.subList(0, limit);
    }

    /*
     * Adds one observation of sampleId under the given feature name -> value pairs, and
     * increments sampleId's class count by one.
     */
    public void addSample(String sampleId, Map<String, String> features) {
        for (Map.Entry<String, String> entry : features.entrySet()) {
            useFeature(entry.getKey()).addSample(sampleId, entry.getValue());
        }

        Integer sampleCount = mClassCounts.get(sampleId);
        mClassCounts.put(sampleId, (sampleCount == null) ? 1 : sampleCount + 1);
    }

    /*
     * Resets the predictor to an empty model.
     */
    public void resetPredictor() {
        // TODO: not sure this step would reduce memory waste
        for (HistogramCounter counter : mPredictor.values()) {
            counter.resetCounter();
        }
        mPredictor.clear();
        mClassCounts.clear();
    }

    /*
     * Converts the prediction model (feature name -> feature value -> app -> count) into a
     * byte array using Java serialization. Throws RuntimeException if serialization fails.
     */
    public byte[] getModel() {
        // TODO: convert model to a more memory efficient data structure.
        HashMap<String, HashMap<String, HashMap<String, Integer> > > model =
                new HashMap<String, HashMap<String, HashMap<String, Integer> > >();
        for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) {
            model.put(entry.getKey(), entry.getValue().getCounter());
        }
        return serialize(model);
    }

    /*
     * Replaces the prediction model with one decoded from modelData, e.g. bytes produced
     * by getModel(). The data is decoded and validated before the current model is
     * discarded, so a rejected model leaves the predictor unchanged.
     *
     * Returns true. Throws RuntimeException if the data cannot be deserialized or lacks
     * the "Time of Week" feature.
     */
    public boolean setModel(final byte[] modelData) {
        HashMap<String, HashMap<String, HashMap<String, Integer> > > model =
                deserialize(modelData);
        if (!model.containsKey(TIME_OF_WEEK_FEATURE)) {
            throw new RuntimeException("Precition model error: missing Time of Week!");
        }

        resetPredictor();
        for (Map.Entry<String, HashMap<String, HashMap<String, Integer> > > entry :
                model.entrySet()) {
            useFeature(entry.getKey()).setCounter(entry.getValue());
        }

        // TODO: this is a temporary fix for now
        loadClassCounter();

        return true;
    }

    /*
     * Rebuilds mClassCounts by summing, for each class, its counts over all values of the
     * "Time of Week" feature.
     */
    private void loadClassCounter() {
        HistogramCounter timeOfWeek = mPredictor.get(TIME_OF_WEEK_FEATURE);
        if (timeOfWeek == null) {
            throw new RuntimeException("Precition model error: missing Time of Week!");
        }

        mClassCounts.clear();
        for (HashMap<String, Integer> classCounts : timeOfWeek.getCounter().values()) {
            for (Map.Entry<String, Integer> entry : classCounts.entrySet()) {
                String className = entry.getKey();
                int classCount = entry.getValue();
                Integer previous = mClassCounts.get(className);
                mClassCounts.put(className,
                        (previous == null) ? classCount : classCount + previous);
            }
        }
        Log.i(TAG, "class counts: " + mClassCounts);
    }

    /*
     * Returns the histogram for featureName, creating an empty one first if needed.
     */
    private HistogramCounter useFeature(String featureName) {
        HistogramCounter counter = mPredictor.get(featureName);
        if (counter == null) {
            counter = new HistogramCounter();
            mPredictor.put(featureName, counter);
        }
        return counter;
    }

    /*
     * Serializes object with Java serialization. Throws RuntimeException (with the
     * IOException as cause) on failure.
     */
    private static byte[] serialize(Object object) {
        ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
        try {
            ObjectOutputStream objStream = new ObjectOutputStream(byteStream);
            objStream.writeObject(object);
            // close() flushes anything still buffered before the bytes are read out.
            objStream.close();
        } catch (IOException e) {
            throw new RuntimeException("Can't get model", e);
        }
        return byteStream.toByteArray();
    }

    /*
     * Deserializes data with Java serialization and casts it to the caller's expected type.
     * Throws RuntimeException (with the original exception as cause) on failure.
     * NOTE(review): Java native deserialization must only be fed trusted bytes, such as
     * those produced by getModel(); confirm where callers obtain modelData.
     */
    @SuppressWarnings("unchecked")
    private static <T> T deserialize(byte[] data) {
        try {
            ObjectInputStream objStream = new ObjectInputStream(new ByteArrayInputStream(data));
            try {
                return (T) objStream.readObject();
            } finally {
                objStream.close();
            }
        } catch (IOException e) {
            throw new RuntimeException("Can't load model", e);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException("Learning class not found", e);
        }
    }
}