diff options
author | Luis Hector Chavez <lhchavez@google.com> | 2018-11-01 20:02:22 -0700 |
---|---|---|
committer | Luis Hector Chavez <lhchavez@google.com> | 2019-03-25 17:02:26 -0700 |
commit | a54812b1dff7d7d66fb914bd6d8c8cdf5b66cedf (patch) | |
tree | 6597ce3f32639d5903d6e5fd9b828fa02d635fb2 | |
parent | b21da7a85afbb8b835fc72c84eaad2b3be87a7e8 (diff) | |
download | minijail-a54812b1dff7d7d66fb914bd6d8c8cdf5b66cedf.tar.gz |
tools/compile_seccomp_policy: Significantly improve BST codegen
This change makes the BST compilation strategy significantly better
using two optimizations:
* Use a Dynamic Programming approach to consider every possible BST, and
choose the one that has the least cost.
* At each step, consider switching between a BST and the linear chain
model. This spares very short models the overhead of the
intermediate nodes.
Bug: chromium:856315
Test: ./tools/compiler_unittest.py
Change-Id: Idf90f1be528d760b792c3cd21c76cd1040a9c086
-rw-r--r-- | tools/bpf.py | 20 | ||||
-rw-r--r-- | tools/compiler.py | 204 | ||||
-rwxr-xr-x | tools/compiler_unittest.py | 35 |
3 files changed, 151 insertions, 108 deletions
diff --git a/tools/bpf.py b/tools/bpf.py index e89e93f..bd7007e 100644 --- a/tools/bpf.py +++ b/tools/bpf.py @@ -651,3 +651,23 @@ class FlatteningVisitor: self._instructions = instructions + self._instructions self._offsets[id(block)] = -len(self._instructions) return + + +class ArgFilterForwardingVisitor: + """A visitor that forwards visitation to all arg filters.""" + + def __init__(self, visitor): + self.visitor = visitor + + def visit(self, block): + # All arg filters are BasicBlocks. + if not isinstance(block, BasicBlock): + return + # But the ALLOW, KILL_PROCESS, TRAP, etc. actions are too and we don't + # want to visit them just yet. + if (isinstance(block, KillProcess) or isinstance(block, KillThread) + or isinstance(block, Trap) or isinstance(block, ReturnErrno) + or isinstance(block, Trace) or isinstance(block, Log) + or isinstance(block, Allow)): + return + block.accept(self.visitor) diff --git a/tools/compiler.py b/tools/compiler.py index 73053d2..9895df8 100644 --- a/tools/compiler.py +++ b/tools/compiler.py @@ -103,52 +103,63 @@ def _convert_to_ranges(entries): def _compile_single_range(entry, accept_action, reject_action, - visitor, lower_bound=0, upper_bound=1e99): action = accept_action if entry.filter: - entry.filter.accept(visitor) action = entry.filter if entry.numbers[1] - entry.numbers[0] == 1: # Single syscall. # Accept if |X == nr|. - return bpf.SyscallEntry( - entry.numbers[0], action, reject_action, op=bpf.BPF_JEQ) + return (1, + bpf.SyscallEntry( + entry.numbers[0], action, reject_action, op=bpf.BPF_JEQ)) elif entry.numbers[0] == lower_bound: # Syscall range aligned with the lower bound. # Accept if |X < nr[1]|. - return bpf.SyscallEntry( - entry.numbers[1], reject_action, action, op=bpf.BPF_JGE) + return (1, + bpf.SyscallEntry( + entry.numbers[1], reject_action, action, op=bpf.BPF_JGE)) elif entry.numbers[1] == upper_bound: # Syscall range aligned with the upper bound. # Accept if |X >= nr[0]|. 
- return bpf.SyscallEntry( - entry.numbers[0], action, reject_action, op=bpf.BPF_JGE) + return (1, + bpf.SyscallEntry( + entry.numbers[0], action, reject_action, op=bpf.BPF_JGE)) # Syscall range in the middle. # Accept if |nr[0] <= X < nr[1]|. upper_entry = bpf.SyscallEntry( entry.numbers[1], reject_action, action, op=bpf.BPF_JGE) - return bpf.SyscallEntry( - entry.numbers[0], upper_entry, reject_action, op=bpf.BPF_JGE) + return (2, + bpf.SyscallEntry( + entry.numbers[0], upper_entry, reject_action, op=bpf.BPF_JGE)) -def _compile_entries_linear(entries, accept_action, reject_action, visitor): - # Compiles the list of entries into a simple linear list of comparisons. In +def _compile_ranges_linear(ranges, accept_action, reject_action): + # Compiles the list of ranges into a simple linear list of comparisons. In # order to make the generated code a bit more efficient, we sort the - # entries by frequency, so that the most frequently-called syscalls appear + # ranges by frequency, so that the most frequently-called syscalls appear # earlier in the chain. 
+ cost = 0 + accumulated_frequencies = 0 next_action = reject_action - ranges = sorted(_convert_to_ranges(entries), key=lambda r: -r.frequency) - for entry in ranges[::-1]: - next_action = _compile_single_range(entry, accept_action, next_action, - visitor) - return next_action + for entry in sorted(ranges, key=lambda r: r.frequency): + current_cost, next_action = _compile_single_range( + entry, accept_action, next_action) + accumulated_frequencies += entry.frequency + cost += accumulated_frequencies * current_cost + return (cost, next_action) -def _compile_entries_bst(entries, accept_action, reject_action, visitor): +def _compile_entries_linear(entries, accept_action, reject_action): + return _compile_ranges_linear( + _convert_to_ranges(entries), accept_action, reject_action)[1] + + +def _compile_entries_bst(entries, accept_action, reject_action): # Instead of generating a linear list of comparisons, this method generates - # a binary search tree. + # a binary search tree, where some of the leaves can be linear chains of + # comparisons. # # Even though we are going to perform a binary search over the syscall # number, we would still like to rotate some of the internal nodes of the @@ -156,17 +167,24 @@ def _compile_entries_bst(entries, accept_action, reject_action, visitor): # more cheaply (i.e. fewer internal nodes need to be traversed to reach # them). # - # The overall idea then is to, at any step, instead of naively partitioning - # the list of syscalls by the midpoint of the interval, we choose a - # midpoint that minimizes the difference of the sum of all frequencies - # between the left and right subtrees. For that, we need to sort the - # entries by syscall number and keep track of the accumulated frequency of - # all entries prior to the current one so that we can compue the midpoint - # efficiently. 
+ # This uses Dynamic Programming to generate all possible BSTs efficiently + # (in O(n^3)) so that we can get the absolute minimum-cost tree that matches + # all syscall entries. It does so by considering all of the O(n^2) possible + # sub-intervals, and for each one of those try all of the O(n) partitions of + # that sub-interval. At each step, it considers putting the remaining + # entries in a linear comparison chain as well as another BST, and chooses + # the option that minimizes the total overall cost. # - # TODO(lhchavez): There is one further possible optimization, which is to - # hoist any syscalls that are more frequent than all other syscalls in the - # BST combined into a linear chain before entering the BST. + # Between every pair of non-contiguous allowed syscalls, there are two + # locally optimal options as to where to set the partition for the + # subsequent ranges: aligned to the end of the left subrange or to the + # beginning of the right subrange. The fact that these two options have + # slightly different costs, combined with the possibility of a subtree to + # use the linear chain strategy (which has a completely different cost + # model), causes the target cost function that we are trying to optimize to + # not be unimodal / convex. This unfortunately means that more clever + # techniques like using ternary search (which would reduce the overall + # complexity to O(n^2 log n)) do not work in all cases. ranges = list(_convert_to_ranges(entries)) accumulated = 0 @@ -174,67 +192,66 @@ def _compile_entries_bst(entries, accept_action, reject_action, visitor): accumulated += entry.frequency entry.accumulated = accumulated - # Recursively create the internal nodes. - def _generate_syscall_bst(ranges, lower_bound=0, upper_bound=2**64 - 1): - assert ranges - if len(ranges) == 1: - # This is a single syscall entry range, but the interval we are - # currently considering contains other syscalls that we want to - # reject. 
So instead of an internal node, create one or more leaf - # nodes that check the range. - assert lower_bound < upper_bound - return _compile_single_range(ranges[0], accept_action, - reject_action, visitor, lower_bound, - upper_bound) - - # Find the midpoint that minimizes the difference between accumulated - # costs in the left and right subtrees. - previous_accumulated = ranges[0].accumulated - ranges[0].frequency - last_accumulated = ranges[-1].accumulated - previous_accumulated - best = (1e99, -1) - for i, entry in enumerate(ranges): - if not i: - continue - left_accumulated = entry.accumulated - previous_accumulated - right_accumulated = last_accumulated - left_accumulated - best = min(best, (abs(left_accumulated - right_accumulated), i)) - midpoint = best[1] - assert midpoint >= 1, best - - cutoff_bound = ranges[midpoint].numbers[0] - - # Now we build the right and left subtrees independently. If any of the - # subtrees consist of a single entry _and_ the bounds are tight around - # that entry (that is, the bounds contain _only_ the syscall we are - # going to consider), we can avoid emitting a leaf node and instead - # have the comparison jump directly into the action that would be taken - # by the entry. - if (cutoff_bound, upper_bound) == ranges[midpoint].numbers: - if ranges[midpoint].filter: - ranges[midpoint].filter.accept(visitor) - right_subtree = ranges[midpoint].filter - else: - right_subtree = accept_action - else: - right_subtree = _generate_syscall_bst(ranges[midpoint:], - cutoff_bound, upper_bound) - - if (lower_bound, cutoff_bound) == ranges[midpoint - 1].numbers: - if ranges[midpoint - 1].filter: - ranges[midpoint - 1].filter.accept(visitor) - left_subtree = ranges[midpoint - 1].filter - else: - left_subtree = accept_action - else: - left_subtree = _generate_syscall_bst(ranges[:midpoint], - lower_bound, cutoff_bound) - - # Finally, now that both subtrees have been generated, we can create - # the internal node of the binary search tree. 
- return bpf.SyscallEntry( - cutoff_bound, right_subtree, left_subtree, op=bpf.BPF_JGE) - - return _generate_syscall_bst(ranges) + # Memoization cache to build the DP table top-down, which is easier to + # understand. + memoized_costs = {} + + def _generate_syscall_bst(ranges, indices, bounds=(0, 2**64 - 1)): + assert bounds[0] <= ranges[indices[0]].numbers[0], (indices, bounds) + assert ranges[indices[1] - 1].numbers[1] <= bounds[1], (indices, + bounds) + + if bounds in memoized_costs: + return memoized_costs[bounds] + if indices[1] - indices[0] == 1: + if bounds == ranges[indices[0]].numbers: + # If bounds are tight around the syscall, it costs nothing. + memoized_costs[bounds] = (0, ranges[indices[0]].filter + or accept_action) + return memoized_costs[bounds] + result = _compile_single_range(ranges[indices[0]], accept_action, + reject_action) + memoized_costs[bounds] = (result[0] * ranges[indices[0]].frequency, + result[1]) + return memoized_costs[bounds] + + # Try the linear model first and use that as the best estimate so far. + best_cost = _compile_ranges_linear(ranges[slice(*indices)], + accept_action, reject_action) + + # Now recursively go through all possible partitions of the interval + # currently being considered. 
+ previous_accumulated = ranges[indices[0]].accumulated - ranges[indices[0]].frequency + bst_comparison_cost = ( + ranges[indices[1] - 1].accumulated - previous_accumulated) + for i, entry in enumerate(ranges[slice(*indices)]): + candidates = [entry.numbers[0]] + if i: + candidates.append(ranges[i - 1 + indices[0]].numbers[1]) + for cutoff_bound in candidates: + if not bounds[0] < cutoff_bound < bounds[1]: + continue + if not indices[0] < i + indices[0] < indices[1]: + continue + left_subtree = _generate_syscall_bst( + ranges, (indices[0], i + indices[0]), + (bounds[0], cutoff_bound)) + right_subtree = _generate_syscall_bst( + ranges, (i + indices[0], indices[1]), + (cutoff_bound, bounds[1])) + best_cost = min( + best_cost, + (bst_comparison_cost + left_subtree[0] + right_subtree[0], + bpf.SyscallEntry( + cutoff_bound, + right_subtree[1], + left_subtree[1], + op=bpf.BPF_JGE))) + + memoized_costs[bounds] = best_cost + return memoized_costs[bounds] + + return _generate_syscall_bst(ranges, (0, len(ranges)))[1] class PolicyCompiler: @@ -268,10 +285,11 @@ class PolicyCompiler: if entries: if optimization_strategy == OptimizationStrategy.BST: next_action = _compile_entries_bst(entries, accept_action, - reject_action, visitor) + reject_action) else: next_action = _compile_entries_linear(entries, accept_action, - reject_action, visitor) + reject_action) + next_action.accept(bpf.ArgFilterForwardingVisitor(visitor)) reject_action.accept(visitor) accept_action.accept(visitor) bpf.ValidateArch(next_action).accept(visitor) diff --git a/tools/compiler_unittest.py b/tools/compiler_unittest.py index cfa2b8d..ae00d7f 100755 --- a/tools/compiler_unittest.py +++ b/tools/compiler_unittest.py @@ -292,8 +292,8 @@ class CompileFileTests(unittest.TestCase): outf.write(contents) return path - def test_compile_linear(self): - """Reject empty / malformed lines.""" + def test_compile(self): + """Ensure compilation works with all strategies.""" self._write_file( 'test.frequency', """ read: 1 
@@ -331,18 +331,23 @@ class CompileFileTests(unittest.TestCase): close: 1 """) - program = self.compiler.compile_file( - path, - optimization_strategy=compiler.OptimizationStrategy.BST, - kill_action=bpf.KillProcess()) - # BST for very few syscalls does not make a lot of sense and does - # introduce some overhead, so there will be no checking for cost. - self.assertEqual( - bpf.simulate(program.instructions, self.arch.arch_nr, - self.arch.syscalls['read'], 0)[1], 'ALLOW') - self.assertEqual( - bpf.simulate(program.instructions, self.arch.arch_nr, - self.arch.syscalls['close'], 0)[1], 'ALLOW') + for strategy in list(compiler.OptimizationStrategy): + program = self.compiler.compile_file( + path, + optimization_strategy=strategy, + kill_action=bpf.KillProcess()) + self.assertGreater( + bpf.simulate(program.instructions, self.arch.arch_nr, + self.arch.syscalls['read'], 0)[0], + bpf.simulate(program.instructions, self.arch.arch_nr, + self.arch.syscalls['close'], 0)[0], + ) + self.assertEqual( + bpf.simulate(program.instructions, self.arch.arch_nr, + self.arch.syscalls['read'], 0)[1], 'ALLOW') + self.assertEqual( + bpf.simulate(program.instructions, self.arch.arch_nr, + self.arch.syscalls['close'], 0)[1], 'ALLOW') def test_compile_empty_file(self): """Accept empty files.""" @@ -362,7 +367,7 @@ class CompileFileTests(unittest.TestCase): def test_compile_simulate(self): """Ensure policy reflects script by testing some random scripts.""" - iterations = 10 + iterations = 5 for i in range(iterations): num_entries = len(self.arch.syscalls) * (i + 1) // iterations syscalls = dict( |