diff options
Diffstat (limited to 'tests/gen_counts.R')
-rwxr-xr-x | tests/gen_counts.R | 213 |
1 files changed, 213 insertions, 0 deletions
diff --git a/tests/gen_counts.R b/tests/gen_counts.R new file mode 100755 index 0000000..769677c --- /dev/null +++ b/tests/gen_counts.R @@ -0,0 +1,213 @@ +#!/usr/bin/env Rscript +# +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +source('analysis/R/read_input.R') + +RandomPartition <- function(total, weights) { + # Outputs a random partition according to a specified distribution + # Args: + # total - number of samples + # weights - weights that are proportional to the probability density + # function of the target distribution + # Returns: + # a histogram sampled according to the pdf + # Example: + # > RandomPartition(100, c(3, 2, 1, 0, 1)) + # [1] 47 24 15 0 14 + if (any(weights < 0)) + stop("Probabilities cannot be negative") + + if (sum(weights) == 0) + stop("Probabilities cannot sum up to 0") + + bins <- length(weights) + result <- rep(0, bins) + + # idiomatic way: + # rnd_list <- sample(strs, total, replace = TRUE, weights) + # apply(as.array(strs), 1, function(x) length(rnd_list[rnd_list == x])) + # + # The following is much faster for larger totals. We can replace a loop with + # (tail) recusion, but R chokes with the recursion depth > 850. + + w <- sum(weights) + + for (i in 1:bins) + if (total > 0) { # if total == 0, nothing else to do + # invariant: w = sum(weights[i:bins]) + # rather than computing sum every time leading to quadratic time, keep + # updating it + + # The probability p is clamped to [0, 1] to avoid under/overflow errors. + p <- min(max(weights[i] / w, 0), 1) + # draw the number of balls falling into the current bin + rnd_draw <- rbinom(n = 1, size = total, prob = p) + result[i] <- rnd_draw # push rnd_draw balls from total to result[i] + total <- total - rnd_draw + w <- w - weights[i] + } + + names(result) <- names(weights) + + return(result) +} + +GenerateCounts <- function(params, true_map, partition, reports_per_client) { + # Fast simulation of the marginal table for RAPPOR reports + # Args: + # params - parameters of the RAPPOR reporting process + # true_map - hashed true inputs + # partition - allocation of clients between true values + # reports_per_client - number of reports (IRRs) per client + if (nrow(true_map) != (params$m * params$k)) { + stop(cat("Map does not match the params file!", + "mk =", params$m * params$k, + "nrow(map):", nrow(true_map), + sep = " ")) + } + + # For each reporting type computes its allocation to cohorts. + # Output is an m x strs matrix. + cohorts <- as.matrix( + apply(as.data.frame(partition), 1, + function(count) RandomPartition(count, rep(1, params$m)))) + + # Expands to (m x k) x strs matrix, where each element (corresponding to the + # bit in the aggregate Bloom filter) is repeated k times. + expanded <- apply(cohorts, 2, function(vec) rep(vec, each = params$k)) + + # For each bit, the number of clients reporting this bit: + clients_per_bit <- rep(apply(cohorts, 1, sum), each = params$k) + + # Computes the true number of bits set to one BEFORE PRR. + true_ones <- apply(expanded * true_map, 1, sum) + + ones_in_prr <- + unlist(lapply(true_ones, + function(x) rbinom(n = 1, size = x, prob = 1 - params$f / 2))) + + unlist(lapply(clients_per_bit - true_ones, # clients where the bit is 0 + function(x) rbinom(n = 1, size = x, prob = params$f / 2))) + + # Number of IRRs where each bit is reported (either as 0 or as 1) + reports_per_bit <- clients_per_bit * reports_per_client + + ones_before_irr <- ones_in_prr * reports_per_client + + ones_after_irr <- + unlist(lapply(ones_before_irr, + function(x) rbinom(n = 1, size = x, prob = params$q))) + + unlist(lapply(reports_per_bit - ones_before_irr, + function(x) rbinom(n = 1, size = x, prob = params$p))) + + counts <- cbind(apply(cohorts, 1, sum) * reports_per_client, + matrix(ones_after_irr, nrow = params$m, ncol = params$k, byrow = TRUE)) + + if(any(is.na(counts))) + stop("Failed to generate bit counts. Likely due to integer overflow.") + + counts +} + +ComputePdf <- function(distr, range) { + # Outputs discrete probability density function for a given distribution + + # These are the five distributions in gen_sim_input.py + if (distr == 'exp') { + pdf <- dexp(1:range, rate = 5 / range) + } else if (distr == 'gauss') { + half <- range / 2 + left <- -half + 1 + pdf <- dnorm(left : half, sd = range / 6) + } else if (distr == 'unif') { + # e.g. for N = 4, weights are [0.25, 0.25, 0.25, 0.25] + pdf <- dunif(1:range, max = range) + } else if (distr == 'zipf1') { + # Since the distrubition defined over a finite set, we allow the parameter + # of the Zipf distribution to be 1. + pdf <- sapply(1:range, function(x) 1 / x) + } else if (distr == 'zipf1.5') { + pdf <- sapply(1:range, function(x) 1 / x^1.5) + } + else { + stop(sprintf("Invalid distribution '%s'", distr)) + } + + pdf <- pdf / sum(pdf) # normalize + + pdf +} + +# Usage: +# +# $ ./gen_counts.R exp 10000 1 foo_params.csv foo_true_map.csv foo +# +# Inputs: +# distribution name +# number of clients +# reports per client +# parameters file +# map file +# prefix for output files +# Outputs: +# foo_counts.csv +# foo_hist.csv +# +# Warning: the number of reports in any cohort must be less than +# .Machine$integer.max + +main <- function(argv) { + distr <- argv[[1]] + num_clients <- as.integer(argv[[2]]) + reports_per_client <- as.integer(argv[[3]]) + params_file <- argv[[4]] + true_map_file <- argv[[5]] + out_prefix <- argv[[6]] + + params <- ReadParameterFile(params_file) + + true_map <- ReadMapFile(true_map_file, params) + + num_unique_values <- length(true_map$strs) + + pdf <- ComputePdf(distr, num_unique_values) + + # Computes the number of clients reporting each string + # according to the pre-specified distribution. + partition <- RandomPartition(num_clients, pdf) + + # Histogram + true_hist <- data.frame(string = true_map$strs, count = partition) + + counts <- GenerateCounts(params, true_map$map, partition, reports_per_client) + + # Now create a CSV file + + # Opposite of ReadCountsFile in read_input.R + # http://stackoverflow.com/questions/6750546/export-csv-without-col-names + counts_path <- paste0(out_prefix, '_counts.csv') + write.table(counts, file = counts_path, + row.names = FALSE, col.names = FALSE, sep = ',') + cat(sprintf('Wrote %s\n', counts_path)) + + # TODO: Don't write strings that appear 0 times? + hist_path <- paste0(out_prefix, '_hist.csv') + write.csv(true_hist, file = hist_path, row.names = FALSE) + cat(sprintf('Wrote %s\n', hist_path)) +} + +if (length(sys.frames()) == 0) { + main(commandArgs(TRUE)) +} |