aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Luis Hector Chavez <lhchavez@google.com> 2018-11-01 20:02:22 -0700
committer: Luis Hector Chavez <lhchavez@google.com> 2019-03-25 17:02:26 -0700
commit: a54812b1dff7d7d66fb914bd6d8c8cdf5b66cedf (patch)
tree: 6597ce3f32639d5903d6e5fd9b828fa02d635fb2
parent: b21da7a85afbb8b835fc72c84eaad2b3be87a7e8 (diff)
download: minijail-a54812b1dff7d7d66fb914bd6d8c8cdf5b66cedf.tar.gz
tools/compile_seccomp_policy: Significantly improve BST codegen
This change makes the BST compilation strategy significantly better using two optimizations:

* Use a Dynamic Programming approach to consider every possible BST, and choose the one that has the least cost.
* At each step, consider switching between a BST and the linear chain model. This makes very short models not need to pay the overhead of the intermediate nodes.

Bug: chromium:856315
Test: ./tools/compiler_unittest.py
Change-Id: Idf90f1be528d760b792c3cd21c76cd1040a9c086
-rw-r--r-- tools/bpf.py 20
-rw-r--r-- tools/compiler.py 204
-rwxr-xr-x tools/compiler_unittest.py 35
3 files changed, 151 insertions, 108 deletions
diff --git a/tools/bpf.py b/tools/bpf.py
index e89e93f..bd7007e 100644
--- a/tools/bpf.py
+++ b/tools/bpf.py
@@ -651,3 +651,23 @@ class FlatteningVisitor:
self._instructions = instructions + self._instructions
self._offsets[id(block)] = -len(self._instructions)
return
+
+
+class ArgFilterForwardingVisitor:
+ """A visitor that forwards visitation to all arg filters."""
+
+ def __init__(self, visitor):
+ self.visitor = visitor
+
+ def visit(self, block):
+ # All arg filters are BasicBlocks.
+ if not isinstance(block, BasicBlock):
+ return
+ # But the ALLOW, KILL_PROCESS, TRAP, etc. actions are too and we don't
+ # want to visit them just yet.
+ if (isinstance(block, KillProcess) or isinstance(block, KillThread)
+ or isinstance(block, Trap) or isinstance(block, ReturnErrno)
+ or isinstance(block, Trace) or isinstance(block, Log)
+ or isinstance(block, Allow)):
+ return
+ block.accept(self.visitor)
diff --git a/tools/compiler.py b/tools/compiler.py
index 73053d2..9895df8 100644
--- a/tools/compiler.py
+++ b/tools/compiler.py
@@ -103,52 +103,63 @@ def _convert_to_ranges(entries):
def _compile_single_range(entry,
accept_action,
reject_action,
- visitor,
lower_bound=0,
upper_bound=1e99):
action = accept_action
if entry.filter:
- entry.filter.accept(visitor)
action = entry.filter
if entry.numbers[1] - entry.numbers[0] == 1:
# Single syscall.
# Accept if |X == nr|.
- return bpf.SyscallEntry(
- entry.numbers[0], action, reject_action, op=bpf.BPF_JEQ)
+ return (1,
+ bpf.SyscallEntry(
+ entry.numbers[0], action, reject_action, op=bpf.BPF_JEQ))
elif entry.numbers[0] == lower_bound:
# Syscall range aligned with the lower bound.
# Accept if |X < nr[1]|.
- return bpf.SyscallEntry(
- entry.numbers[1], reject_action, action, op=bpf.BPF_JGE)
+ return (1,
+ bpf.SyscallEntry(
+ entry.numbers[1], reject_action, action, op=bpf.BPF_JGE))
elif entry.numbers[1] == upper_bound:
# Syscall range aligned with the upper bound.
# Accept if |X >= nr[0]|.
- return bpf.SyscallEntry(
- entry.numbers[0], action, reject_action, op=bpf.BPF_JGE)
+ return (1,
+ bpf.SyscallEntry(
+ entry.numbers[0], action, reject_action, op=bpf.BPF_JGE))
# Syscall range in the middle.
# Accept if |nr[0] <= X < nr[1]|.
upper_entry = bpf.SyscallEntry(
entry.numbers[1], reject_action, action, op=bpf.BPF_JGE)
- return bpf.SyscallEntry(
- entry.numbers[0], upper_entry, reject_action, op=bpf.BPF_JGE)
+ return (2,
+ bpf.SyscallEntry(
+ entry.numbers[0], upper_entry, reject_action, op=bpf.BPF_JGE))
-def _compile_entries_linear(entries, accept_action, reject_action, visitor):
- # Compiles the list of entries into a simple linear list of comparisons. In
+def _compile_ranges_linear(ranges, accept_action, reject_action):
+ # Compiles the list of ranges into a simple linear list of comparisons. In
# order to make the generated code a bit more efficient, we sort the
- # entries by frequency, so that the most frequently-called syscalls appear
+ # ranges by frequency, so that the most frequently-called syscalls appear
# earlier in the chain.
+ cost = 0
+ accumulated_frequencies = 0
next_action = reject_action
- ranges = sorted(_convert_to_ranges(entries), key=lambda r: -r.frequency)
- for entry in ranges[::-1]:
- next_action = _compile_single_range(entry, accept_action, next_action,
- visitor)
- return next_action
+ for entry in sorted(ranges, key=lambda r: r.frequency):
+ current_cost, next_action = _compile_single_range(
+ entry, accept_action, next_action)
+ accumulated_frequencies += entry.frequency
+ cost += accumulated_frequencies * current_cost
+ return (cost, next_action)
-def _compile_entries_bst(entries, accept_action, reject_action, visitor):
+def _compile_entries_linear(entries, accept_action, reject_action):
+ return _compile_ranges_linear(
+ _convert_to_ranges(entries), accept_action, reject_action)[1]
+
+
+def _compile_entries_bst(entries, accept_action, reject_action):
# Instead of generating a linear list of comparisons, this method generates
- # a binary search tree.
+ # a binary search tree, where some of the leaves can be linear chains of
+ # comparisons.
#
# Even though we are going to perform a binary search over the syscall
# number, we would still like to rotate some of the internal nodes of the
@@ -156,17 +167,24 @@ def _compile_entries_bst(entries, accept_action, reject_action, visitor):
# more cheaply (i.e. fewer internal nodes need to be traversed to reach
# them).
#
- # The overall idea then is to, at any step, instead of naively partitioning
- # the list of syscalls by the midpoint of the interval, we choose a
- # midpoint that minimizes the difference of the sum of all frequencies
- # between the left and right subtrees. For that, we need to sort the
- # entries by syscall number and keep track of the accumulated frequency of
- # all entries prior to the current one so that we can compue the midpoint
- # efficiently.
+ # This uses Dynamic Programming to generate all possible BSTs efficiently
+ # (in O(n^3)) so that we can get the absolute minimum-cost tree that matches
+ # all syscall entries. It does so by considering all of the O(n^2) possible
+ # sub-intervals, and for each one of those try all of the O(n) partitions of
+ # that sub-interval. At each step, it considers putting the remaining
+ # entries in a linear comparison chain as well as another BST, and chooses
+ # the option that minimizes the total overall cost.
#
- # TODO(lhchavez): There is one further possible optimization, which is to
- # hoist any syscalls that are more frequent than all other syscalls in the
- # BST combined into a linear chain before entering the BST.
+ # Between every pair of non-contiguous allowed syscalls, there are two
+ # locally optimal options as to where to set the partition for the
+ # subsequent ranges: aligned to the end of the left subrange or to the
+ # beginning of the right subrange. The fact that these two options have
+ # slightly different costs, combined with the possibility of a subtree to
+ # use the linear chain strategy (which has a completely different cost
+ # model), causes the target cost function that we are trying to optimize to
+ # not be unimodal / convex. This unfortunately means that more clever
+ # techniques like using ternary search (which would reduce the overall
+ # complexity to O(n^2 log n)) do not work in all cases.
ranges = list(_convert_to_ranges(entries))
accumulated = 0
@@ -174,67 +192,66 @@ def _compile_entries_bst(entries, accept_action, reject_action, visitor):
accumulated += entry.frequency
entry.accumulated = accumulated
- # Recursively create the internal nodes.
- def _generate_syscall_bst(ranges, lower_bound=0, upper_bound=2**64 - 1):
- assert ranges
- if len(ranges) == 1:
- # This is a single syscall entry range, but the interval we are
- # currently considering contains other syscalls that we want to
- # reject. So instead of an internal node, create one or more leaf
- # nodes that check the range.
- assert lower_bound < upper_bound
- return _compile_single_range(ranges[0], accept_action,
- reject_action, visitor, lower_bound,
- upper_bound)
-
- # Find the midpoint that minimizes the difference between accumulated
- # costs in the left and right subtrees.
- previous_accumulated = ranges[0].accumulated - ranges[0].frequency
- last_accumulated = ranges[-1].accumulated - previous_accumulated
- best = (1e99, -1)
- for i, entry in enumerate(ranges):
- if not i:
- continue
- left_accumulated = entry.accumulated - previous_accumulated
- right_accumulated = last_accumulated - left_accumulated
- best = min(best, (abs(left_accumulated - right_accumulated), i))
- midpoint = best[1]
- assert midpoint >= 1, best
-
- cutoff_bound = ranges[midpoint].numbers[0]
-
- # Now we build the right and left subtrees independently. If any of the
- # subtrees consist of a single entry _and_ the bounds are tight around
- # that entry (that is, the bounds contain _only_ the syscall we are
- # going to consider), we can avoid emitting a leaf node and instead
- # have the comparison jump directly into the action that would be taken
- # by the entry.
- if (cutoff_bound, upper_bound) == ranges[midpoint].numbers:
- if ranges[midpoint].filter:
- ranges[midpoint].filter.accept(visitor)
- right_subtree = ranges[midpoint].filter
- else:
- right_subtree = accept_action
- else:
- right_subtree = _generate_syscall_bst(ranges[midpoint:],
- cutoff_bound, upper_bound)
-
- if (lower_bound, cutoff_bound) == ranges[midpoint - 1].numbers:
- if ranges[midpoint - 1].filter:
- ranges[midpoint - 1].filter.accept(visitor)
- left_subtree = ranges[midpoint - 1].filter
- else:
- left_subtree = accept_action
- else:
- left_subtree = _generate_syscall_bst(ranges[:midpoint],
- lower_bound, cutoff_bound)
-
- # Finally, now that both subtrees have been generated, we can create
- # the internal node of the binary search tree.
- return bpf.SyscallEntry(
- cutoff_bound, right_subtree, left_subtree, op=bpf.BPF_JGE)
-
- return _generate_syscall_bst(ranges)
+ # Memoization cache to build the DP table top-down, which is easier to
+ # understand.
+ memoized_costs = {}
+
+ def _generate_syscall_bst(ranges, indices, bounds=(0, 2**64 - 1)):
+ assert bounds[0] <= ranges[indices[0]].numbers[0], (indices, bounds)
+ assert ranges[indices[1] - 1].numbers[1] <= bounds[1], (indices,
+ bounds)
+
+ if bounds in memoized_costs:
+ return memoized_costs[bounds]
+ if indices[1] - indices[0] == 1:
+ if bounds == ranges[indices[0]].numbers:
+ # If bounds are tight around the syscall, it costs nothing.
+ memoized_costs[bounds] = (0, ranges[indices[0]].filter
+ or accept_action)
+ return memoized_costs[bounds]
+ result = _compile_single_range(ranges[indices[0]], accept_action,
+ reject_action)
+ memoized_costs[bounds] = (result[0] * ranges[indices[0]].frequency,
+ result[1])
+ return memoized_costs[bounds]
+
+ # Try the linear model first and use that as the best estimate so far.
+ best_cost = _compile_ranges_linear(ranges[slice(*indices)],
+ accept_action, reject_action)
+
+ # Now recursively go through all possible partitions of the interval
+ # currently being considered.
+ previous_accumulated = ranges[indices[0]].accumulated - ranges[indices[0]].frequency
+ bst_comparison_cost = (
+ ranges[indices[1] - 1].accumulated - previous_accumulated)
+ for i, entry in enumerate(ranges[slice(*indices)]):
+ candidates = [entry.numbers[0]]
+ if i:
+ candidates.append(ranges[i - 1 + indices[0]].numbers[1])
+ for cutoff_bound in candidates:
+ if not bounds[0] < cutoff_bound < bounds[1]:
+ continue
+ if not indices[0] < i + indices[0] < indices[1]:
+ continue
+ left_subtree = _generate_syscall_bst(
+ ranges, (indices[0], i + indices[0]),
+ (bounds[0], cutoff_bound))
+ right_subtree = _generate_syscall_bst(
+ ranges, (i + indices[0], indices[1]),
+ (cutoff_bound, bounds[1]))
+ best_cost = min(
+ best_cost,
+ (bst_comparison_cost + left_subtree[0] + right_subtree[0],
+ bpf.SyscallEntry(
+ cutoff_bound,
+ right_subtree[1],
+ left_subtree[1],
+ op=bpf.BPF_JGE)))
+
+ memoized_costs[bounds] = best_cost
+ return memoized_costs[bounds]
+
+ return _generate_syscall_bst(ranges, (0, len(ranges)))[1]
class PolicyCompiler:
@@ -268,10 +285,11 @@ class PolicyCompiler:
if entries:
if optimization_strategy == OptimizationStrategy.BST:
next_action = _compile_entries_bst(entries, accept_action,
- reject_action, visitor)
+ reject_action)
else:
next_action = _compile_entries_linear(entries, accept_action,
- reject_action, visitor)
+ reject_action)
+ next_action.accept(bpf.ArgFilterForwardingVisitor(visitor))
reject_action.accept(visitor)
accept_action.accept(visitor)
bpf.ValidateArch(next_action).accept(visitor)
diff --git a/tools/compiler_unittest.py b/tools/compiler_unittest.py
index cfa2b8d..ae00d7f 100755
--- a/tools/compiler_unittest.py
+++ b/tools/compiler_unittest.py
@@ -292,8 +292,8 @@ class CompileFileTests(unittest.TestCase):
outf.write(contents)
return path
- def test_compile_linear(self):
- """Reject empty / malformed lines."""
+ def test_compile(self):
+ """Ensure compilation works with all strategies."""
self._write_file(
'test.frequency', """
read: 1
@@ -331,18 +331,23 @@ class CompileFileTests(unittest.TestCase):
close: 1
""")
- program = self.compiler.compile_file(
- path,
- optimization_strategy=compiler.OptimizationStrategy.BST,
- kill_action=bpf.KillProcess())
- # BST for very few syscalls does not make a lot of sense and does
- # introduce some overhead, so there will be no checking for cost.
- self.assertEqual(
- bpf.simulate(program.instructions, self.arch.arch_nr,
- self.arch.syscalls['read'], 0)[1], 'ALLOW')
- self.assertEqual(
- bpf.simulate(program.instructions, self.arch.arch_nr,
- self.arch.syscalls['close'], 0)[1], 'ALLOW')
+ for strategy in list(compiler.OptimizationStrategy):
+ program = self.compiler.compile_file(
+ path,
+ optimization_strategy=strategy,
+ kill_action=bpf.KillProcess())
+ self.assertGreater(
+ bpf.simulate(program.instructions, self.arch.arch_nr,
+ self.arch.syscalls['read'], 0)[0],
+ bpf.simulate(program.instructions, self.arch.arch_nr,
+ self.arch.syscalls['close'], 0)[0],
+ )
+ self.assertEqual(
+ bpf.simulate(program.instructions, self.arch.arch_nr,
+ self.arch.syscalls['read'], 0)[1], 'ALLOW')
+ self.assertEqual(
+ bpf.simulate(program.instructions, self.arch.arch_nr,
+ self.arch.syscalls['close'], 0)[1], 'ALLOW')
def test_compile_empty_file(self):
"""Accept empty files."""
@@ -362,7 +367,7 @@ class CompileFileTests(unittest.TestCase):
def test_compile_simulate(self):
"""Ensure policy reflects script by testing some random scripts."""
- iterations = 10
+ iterations = 5
for i in range(iterations):
num_entries = len(self.arch.syscalls) * (i + 1) // iterations
syscalls = dict(