aboutsummaryrefslogtreecommitdiff
path: root/analysis/tensorflow/fast_em.py
blob: ea001e44efe18f06c7050ab4e088cd0abdf23129 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#!/usr/bin/python
"""
fast_em.py: Tensorflow implementation of expectation maximization for RAPPOR
association analysis.

TODO:
  - Use TensorFlow ops for reading input (so that reading input can be
    distributed)
  - Reduce the number of ops (currently proportional to the number of reports).
    May require new TensorFlow ops.
  - Fix performance bug (v_split is probably being recomputed on every
    iteration):
    bin$ ./test.sh decode-assoc-cpp - 1.1 seconds (single-threaded C++)
    bin$ ./test.sh decode-assoc-tensorflow - 226 seconds on GPU
"""

import sys

import numpy as np
import tensorflow as tf


def log(msg, *args):
  """Write a log line (msg plus a newline) to stderr.

  If args are given, msg is treated as a %-format string and formatted with
  them; otherwise msg is written verbatim, so a literal '%' needs no escaping.
  """
  if args:
    msg = msg % args
  # sys.stderr.write works on Python 2 and 3, unlike the `print >>` statement.
  # The 1-tuple keeps '%s' safe even if msg itself is a tuple.
  sys.stderr.write('%s\n' % (msg,))


def ExpectTag(f, expected):
  """Consume the next 4 bytes of f and check that they equal `expected`.

  Raises:
    RuntimeError: if the bytes read differ from `expected` (including a
      short read at end of file).
  """
  actual = f.read(4)
  if actual == expected:
    return
  raise RuntimeError('Expected %r, got %r' % (expected, actual))


def _ReadUint32(f, what):
  """Read one uint32 (native byte order) from f and return it as a Python int.

  Raises:
    RuntimeError: if the file ends before a complete value is read.
  """
  arr = np.fromfile(f, np.uint32, count=1)
  if len(arr) != 1:
    raise RuntimeError('Unexpected end of file while reading %s' % what)
  # Convert to a Python int so later arithmetic cannot wrap at 32 bits.
  return int(arr[0])


def ReadListOfMatrices(f):
  """Read a big list of conditional probability matrices from a binary file.

  Expected layout; each field is preceded by a 4-byte tag ending in NUL:
    'ne ' <NUL>  uint32                        num_entries (number of matrices)
    'es ' <NUL>  uint32                        entry_size (cells per matrix)
    'dat' <NUL>  float64 x num_entries*entry_size   all values, flat
  Numbers are read with np.fromfile, i.e. in native byte order.

  Args:
    f: file object opened for binary reading.

  Returns:
    (num_entries, entry_size, v): two Python ints and a flat numpy float64
    array of length num_entries * entry_size.

  Raises:
    RuntimeError: if a tag does not match or the file is truncated.
  """
  ExpectTag(f, 'ne \0')
  num_entries = _ReadUint32(f, 'number of entries')
  log('Number of entries: %d', num_entries)

  ExpectTag(f, 'es \0')
  entry_size = _ReadUint32(f, 'entry size')
  log('Entry size: %d', entry_size)

  ExpectTag(f, 'dat\0')
  vec_length = num_entries * entry_size
  v = np.fromfile(f, np.float64, count=vec_length)

  log('Values read: %d', len(v))
  log('v: %s', v[:10])
  # np.fromfile silently returns fewer items on a short file; catch that here
  # instead of letting a truncated input produce wrong results downstream.
  if len(v) != vec_length:
    raise RuntimeError(
        'Expected %d values, got %d (truncated input?)' % (vec_length, len(v)))

  # NOTE: v is returned flat rather than reshaped to
  # (num_entries, entry_size); callers decide how to split it.
  return num_entries, entry_size, v


def WriteTag(f, tag):
  """Write a 3-character tag followed by a NUL byte (4 bytes total) to f.

  Raises:
    AssertionError: if tag is not exactly 3 characters long.
  """
  if len(tag) == 3:
    f.write(tag + '\0')  # the NUL pads the tag to the 4 bytes ExpectTag reads
    return
  raise AssertionError("Tags should be 3 bytes.  Got %r" % tag)


def WriteResult(f, num_em_iters, pij):
  """Serialize the EM result to the open binary file f.

  Output: tag 'emi', the iteration count as a single uint32, then tag 'pij'
  followed by the raw contents of pij as written by ndarray.tofile.
  """
  WriteTag(f, 'emi')
  np.asarray([num_em_iters], dtype=np.uint32).tofile(f)

  WriteTag(f, 'pij')
  pij.tofile(f)


def DebugSum(num_entries, entry_size, v):
  """Log the total of all values in v, computed with TensorFlow.

  This is a sanity check on the input. v is fed through a placeholder of
  length num_entries * entry_size, so its length must equal that product.
  """
  flat_len = num_entries * entry_size
  values_ph = tf.placeholder(tf.float64, shape=(flat_len,))
  total_op = tf.reduce_sum(values_ph)
  with tf.Session() as session:
    total = session.run(total_op, feed_dict={values_ph: v})
  log('Debug sum: %f', total)


def BuildEmIter(num_entries, entry_size, v):
  """Build the TensorFlow graph for a single EM iteration.

  Row i of v (entry_size cells) is report i's conditional probability
  vector v_i. Given the current estimate pij, one EM step computes

    z_i     = (v_i * pij) / sum(v_i * pij)   (elementwise product, per report)
    new_pij = mean over i of z_i

  Args:
    num_entries: number of reports (rows of v).
    entry_size: number of cells per report.
    v: flat float64 array of num_entries * entry_size values, one report's
      cells after another.

  Returns:
    (pij_in, em_iter_expr): a placeholder to feed the current estimate into,
    and a tensor that evaluates to the next estimate.
  """
  # Plain Python ints keep the shape arguments and the divisor free of numpy
  # integer types (e.g. the uint32 values numpy reads from the input header).
  num_entries = int(num_entries)
  entry_size = int(entry_size)

  # Placeholder for the value from the previous iteration.
  pij_in = tf.placeholder(tf.float64, shape=(entry_size,))

  # One (num_entries, entry_size) matrix instead of one tensor per report,
  # so the number of graph ops no longer grows with the number of reports.
  # (v is converted to a graph constant here.)
  cond_prob = tf.reshape(v, [num_entries, entry_size])

  # pij_in (shape [entry_size]) broadcasts across every row.
  z_numerator = cond_prob * pij_in
  # Per-report normalizer, reshaped to a column so it broadcasts across cells.
  sum_z = tf.reshape(tf.reduce_sum(z_numerator, 1), [num_entries, 1])
  z = z_numerator / sum_z

  # Average the per-report posteriors to get the new estimate of Pij.
  em_iter_expr = tf.reduce_sum(z, 0) / float(num_entries)

  return pij_in, em_iter_expr


def RunEm(pij_in, entry_size, em_iter_expr, max_em_iters, epsilon=1e-6):
  """Run the iterative EM algorithm (using the TensorFlow API).

  Starts from the uniform distribution and repeatedly evaluates
  em_iter_expr with the current estimate fed into pij_in. Stops after
  max_em_iters iterations, or earlier once the largest absolute per-cell
  change between consecutive estimates is below epsilon.

  Args:
    pij_in: placeholder that receives the current estimate each iteration.
    entry_size: number of cells in the estimate vector.
    em_iter_expr: tensor computing the next estimate from pij_in.
    max_em_iters: maximum number of EM iterations (0 or less runs none).
    epsilon: convergence threshold on max |new_pij - pij|.

  Returns:
    (num_em_iters, pij): the number of iterations actually run, and the
    final estimate as a numpy.ndarray of length entry_size.
  """
  # Initial value is the uniform distribution.
  pij = np.ones(entry_size) / entry_size

  # Count completed iterations explicitly, so that zero iterations is
  # reported as 0 rather than derived from a loop index.
  num_iters = 0
  with tf.Session() as sess:
    while num_iters < max_em_iters:
      sys.stdout.write('PIJ %s\n' % (pij,))
      new_pij = sess.run(em_iter_expr, feed_dict={pij_in: pij})
      dif = np.max(np.abs(new_pij - pij))
      log('EM iteration %d, dif = %e', num_iters, dif)
      num_iters += 1
      pij = new_pij

      if dif < epsilon:
        log('Early EM termination: %e < %e', dif, epsilon)
        break

  return num_iters, pij


def sep():
  """Write a separator line of 80 dashes to stdout."""
  # sys.stdout.write works on Python 2 and 3, unlike the print statement.
  sys.stdout.write('-' * 80 + '\n')


def main(argv):
  """Command-line entry point.

  Usage: fast_em.py INPUT_PATH OUTPUT_PATH MAX_EM_ITERS

  Reads the conditional probability matrices from INPUT_PATH, runs up to
  MAX_EM_ITERS EM iterations, and writes the iteration count and the final
  Pij estimate to OUTPUT_PATH.

  Raises:
    RuntimeError: if arguments are missing or MAX_EM_ITERS is not an
      integer; also propagated from reading the input file.
  """
  if len(argv) < 4:
    raise RuntimeError('Usage: fast_em.py INPUT_PATH OUTPUT_PATH MAX_EM_ITERS')

  input_path = argv[1]
  output_path = argv[2]
  try:
    max_em_iters = int(argv[3])
  except ValueError:
    raise RuntimeError('MAX_EM_ITERS must be an integer, got %r' % argv[3])

  sep()
  # The input is a binary file; open it in binary mode so the bytes are read
  # unmodified (text mode may translate newlines on some platforms).
  with open(input_path, 'rb') as f:
    num_entries, entry_size, cond_prob = ReadListOfMatrices(f)

  sep()
  DebugSum(num_entries, entry_size, cond_prob)

  sep()
  pij_in, em_iter_expr = BuildEmIter(num_entries, entry_size, cond_prob)
  num_em_iters, pij = RunEm(pij_in, entry_size, em_iter_expr, max_em_iters)

  sep()
  log('Final Pij: %s', pij)

  with open(output_path, 'wb') as f:
    WriteResult(f, num_em_iters, pij)
  log('Wrote %s', output_path)


if __name__ == '__main__':
  try:
    main(sys.argv)
  except RuntimeError as e:
    # Expected failures (bad arguments, malformed input) are reported as a
    # one-line message with exit status 1 rather than a traceback.
    sys.stderr.write('FATAL: %s\n' % e)
    sys.exit(1)