aboutsummaryrefslogtreecommitdiff
path: root/analysis/cpp/fast_em.cc
diff options
context:
space:
mode:
Diffstat (limited to 'analysis/cpp/fast_em.cc')
-rw-r--r--analysis/cpp/fast_em.cc309
1 files changed, 309 insertions, 0 deletions
diff --git a/analysis/cpp/fast_em.cc b/analysis/cpp/fast_em.cc
new file mode 100644
index 0000000..5bdfedb
--- /dev/null
+++ b/analysis/cpp/fast_em.cc
@@ -0,0 +1,309 @@
+// Copyright 2015 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <assert.h>
+#include <stdarg.h> // va_list, etc.
+#include <stdio.h> // fread()
+#include <stdlib.h> // exit()
+#include <stdint.h> // uint16_t
+#include <string.h> // strcmp()
+#include <cmath> // std::abs operates on doubles
+#include <cstdlib> // strtol
+#include <vector>
+
+using std::vector;
+
+// Log messages to stdout.
+void log(const char* fmt, ...) {
+ va_list args;
+ va_start(args, fmt);
+ vprintf(fmt, args);
+ va_end(args);
+ printf("\n");
+}
+
+const int kTagLen = 4; // 4 byte tags in the file format
+
+bool ExpectTag(FILE* f, const char* tag) {
+ char buf[kTagLen];
+
+ if (fread(buf, sizeof buf[0], kTagLen, f) != kTagLen) {
+ return false;
+ }
+ if (strcmp(buf, tag) != 0) {
+ log("Error: expected '%s'", tag);
+ return false;
+ }
+ return true;
+}
+
+static bool ReadListOfMatrices(
+ FILE* f, uint32_t* num_entries_out, uint32_t* entry_size_out,
+ vector<double>* v_out) {
+ if (!ExpectTag(f, "ne ")) {
+ return false;
+ }
+
+ // R integers are serialized as uint32_t
+ uint32_t num_entries;
+ if (fread(&num_entries, sizeof num_entries, 1, f) != 1) {
+ return false;
+ }
+
+ log("num entries: %d", num_entries);
+
+ if (!ExpectTag(f, "es ")) {
+ return false;
+ }
+
+ uint32_t entry_size;
+ if (fread(&entry_size, sizeof entry_size, 1, f) != 1) {
+ return false;
+ }
+ log("entry_size: %d", entry_size);
+
+ if (!ExpectTag(f, "dat")) {
+ return false;
+ }
+
+ // Now read dynamic data
+ size_t vec_length = num_entries * entry_size;
+
+ vector<double>& v = *v_out;
+ v.resize(vec_length);
+
+ if (fread(&v[0], sizeof v[0], vec_length, f) != vec_length) {
+ return false;
+ }
+
+ // Print out head for sanity
+ size_t n = 20;
+ for (size_t i = 0; i < n && i < v.size(); ++i) {
+ log("%d: %f", i, v[i]);
+ }
+
+ *num_entries_out = num_entries;
+ *entry_size_out = entry_size;
+
+ return true;
+}
+
+void PrintEntryVector(const vector<double>& cond_prob, size_t m,
+ size_t entry_size) {
+ size_t c_base = m * entry_size;
+ log("cond_prob[m = %d] = ", m);
+ for (size_t i = 0; i < entry_size; ++i) {
+ printf("%e ", cond_prob[c_base + i]);
+ }
+ printf("\n");
+}
+
+void PrintPij(const vector<double>& pij) {
+ double sum = 0.0;
+ printf("PIJ:\n");
+ for (size_t i = 0; i < pij.size(); ++i) {
+ printf("%f ", pij[i]);
+ sum += pij[i];
+ }
+ printf("\n");
+ printf("SUM: %f\n", sum); // sum is 1.0 after normalization
+ printf("\n");
+}
+
+// EM algorithm to iteratively estimate parameters.
+
+static int ExpectationMaximization(
+ uint32_t num_entries, uint32_t entry_size, const vector<double>& cond_prob,
+ int max_em_iters, double epsilon, vector<double>* pij_out) {
+ // Start out with uniform distribution.
+ vector<double> pij(entry_size, 0.0);
+ double init = 1.0 / entry_size;
+ for (size_t i = 0; i < pij.size(); ++i) {
+ pij[i] = init;
+ }
+ log("Initialized %d entries with %f", pij.size(), init);
+
+ vector<double> prev_pij(entry_size, 0.0); // pij on previous iteration
+
+ log("Starting up to %d EM iterations", max_em_iters);
+
+ int em_iter = 0; // visible after loop
+ for (; em_iter < max_em_iters; ++em_iter) {
+ //
+ // lapply() step.
+ //
+
+ // Computed below as a function of old Pij and conditional probability for
+ // each report.
+ vector<double> new_pij(entry_size, 0.0);
+
+ // m is the matrix index, giving the conditional probability matrix for a
+ // single report.
+ for (size_t m = 0; m < num_entries; ++m) {
+ vector<double> z(entry_size, 0.0);
+
+ double sum_z = 0.0;
+
+ // base index for the matrix corresponding to a report.
+ size_t c_base = m * entry_size;
+
+ for (size_t i = 0; i < entry_size; ++i) { // multiply and running sum
+ size_t c_index = c_base + i;
+ z[i] = cond_prob[c_index] * pij[i];
+ sum_z += z[i];
+ }
+
+ // Normalize and Reduce("+", wcp) step. These two steps are combined for
+ // memory locality.
+ for (size_t i = 0; i < entry_size; ++i) {
+ new_pij[i] += z[i] / sum_z;
+ }
+ }
+
+ // Divide outside the loop
+ for (size_t i = 0; i < entry_size; ++i) {
+ new_pij[i] /= num_entries;
+ }
+
+ //PrintPij(new_pij);
+
+ //
+ // Check for termination
+ //
+ double max_dif = 0.0;
+ for (size_t i = 0; i < entry_size; ++i) {
+ double dif = std::abs(new_pij[i] - pij[i]);
+ if (dif > max_dif) {
+ max_dif = dif;
+ }
+ }
+
+ pij = new_pij; // copy
+
+ log("fast EM iteration %d, dif = %e", em_iter, max_dif);
+
+ if (max_dif < epsilon) {
+ log("Early EM termination: %e < %e", max_dif, epsilon);
+ break;
+ }
+ }
+
+ *pij_out = pij;
+ // If we reached iteration index 10, then there were 10 iterations: the last
+ // one terminated the loop.
+ return em_iter;
+}
+
+bool WriteTag(const char* tag, FILE* f_out) {
+ assert(strlen(tag) == 3); // write 3 byte tags with NUL byte
+ return fwrite(tag, 1, 4, f_out) == 4;
+}
+
+// Write the probabilities as a flat list of doubles. The caller knows what
+// the dimensions are.
+bool WriteResult(const vector<double>& pij, uint32_t num_em_iters,
+ FILE* f_out) {
+ if (!WriteTag("emi", f_out)) {
+ return false;
+ }
+ if (fwrite(&num_em_iters, sizeof num_em_iters, 1, f_out) != 1) {
+ return false;
+ }
+
+ if (!WriteTag("pij", f_out)) {
+ return false;
+ }
+ size_t n = pij.size();
+ if (fwrite(&pij[0], sizeof pij[0], n, f_out) != n) {
+ return false;
+ }
+ return true;
+}
+
+// Like atoi, but with basic (not exhaustive) error checking.
+bool StringToInt(const char* s, int* result) {
+ bool ok = true;
+ char* end; // mutated by strtol
+
+ *result = strtol(s, &end, 10); // base 10
+ // If strol didn't consume any characters, it failed.
+ if (end == s) {
+ ok = false;
+ }
+ return ok;
+}
+
+int main(int argc, char **argv) {
+ if (argc < 4) {
+ log("Usage: read_numeric INPUT OUTPUT max_em_iters");
+ return 1;
+ }
+
+ char* in_filename = argv[1];
+ char* out_filename = argv[2];
+
+ int max_em_iters;
+ if (!StringToInt(argv[3], &max_em_iters)) {
+ log("Error parsing max_em_iters");
+ return 1;
+ }
+
+ FILE* f = fopen(in_filename, "rb");
+ if (f == NULL) {
+ return 1;
+ }
+
+ // Try opening first so we don't do a long computation and then fail.
+ FILE* f_out = fopen(out_filename, "wb");
+ if (f_out == NULL) {
+ return 1;
+ }
+
+ uint32_t num_entries;
+ uint32_t entry_size;
+ vector<double> cond_prob;
+ if (!ReadListOfMatrices(f, &num_entries, &entry_size, &cond_prob)) {
+ log("Error reading list of matrices");
+ return 1;
+ }
+
+ fclose(f);
+
+ // Sanity check
+ double debug_sum = 0.0;
+ for (size_t m = 0; m < num_entries; ++m) {
+ // base index for the matrix corresponding to a report.
+ size_t c_base = m * entry_size;
+ for (size_t i = 0; i < entry_size; ++i) { // multiply and running sum
+ debug_sum += cond_prob[c_base + i];
+ }
+ }
+ log("Debug sum: %f", debug_sum);
+
+ double epsilon = 1e-6;
+ log("epsilon: %f", epsilon);
+
+ vector<double> pij(entry_size);
+ int num_em_iters = ExpectationMaximization(
+ num_entries, entry_size, cond_prob, max_em_iters, epsilon, &pij);
+
+ if (!WriteResult(pij, num_em_iters, f_out)) {
+ log("Error writing result matrix");
+ return 1;
+ }
+ fclose(f_out);
+
+ log("fast EM done");
+ return 0;
+}