diff options
author | Luis Hector Chavez <lhchavez@google.com> | 2019-03-27 11:14:24 -0700 |
---|---|---|
committer | android-build-merger <android-build-merger@google.com> | 2019-03-27 11:14:24 -0700 |
commit | b9f69fb285422eefe2e692b5c1a168b3d06055f7 (patch) | |
tree | cee5ad9c88f68df3ceedd347910a052dcf48a117 | |
parent | 1b4ea70c5f0e6e2c5a2204033cb4b3f2c724190f (diff) | |
parent | 31f7fed79c67fd03d7c925a41fa7a14d7410e9a1 (diff) | |
download | minijail-b9f69fb285422eefe2e692b5c1a168b3d06055f7.tar.gz |
tools/compile_seccomp_policy: Optimize AST visitation am: 3b41ae3db9
am: 31f7fed79c
Change-Id: I29747c5babe7f2e445069835683f3641169810d0
-rw-r--r-- | tools/bpf.py | 76 | ||||
-rwxr-xr-x | tools/compiler_unittest.py | 34 |
2 files changed, 82 insertions, 28 deletions
diff --git a/tools/bpf.py b/tools/bpf.py index bd7007e..75db502 100644 --- a/tools/bpf.py +++ b/tools/bpf.py @@ -172,6 +172,8 @@ class BasicBlock(AbstractBlock): self._instructions = instructions def accept(self, visitor): + if visitor.visited(self): + return visitor.visit(self) @property @@ -251,6 +253,8 @@ class ValidateArch(AbstractBlock): self.next_block = next_block def accept(self, visitor): + if visitor.visited(self): + return self.next_block.accept(visitor) visitor.visit(self) @@ -274,6 +278,8 @@ class SyscallEntry(AbstractBlock): return False def accept(self, visitor): + if visitor.visited(self): + return self.jt.accept(visitor) self.jf.accept(visitor) visitor.visit(self) @@ -299,6 +305,8 @@ class WideAtom(AbstractBlock): self.jf = jf def accept(self, visitor): + if visitor.visited(self): + return self.jt.accept(visitor) self.jf.accept(visitor) visitor.visit(self) @@ -344,6 +352,8 @@ class Atom(AbstractBlock): self.value = value def accept(self, visitor): + if visitor.visited(self): + return self.jt.accept(visitor) self.jf.accept(visitor) visitor.visit(self) @@ -352,6 +362,15 @@ class Atom(AbstractBlock): class AbstractVisitor(abc.ABC): """An abstract visitor.""" + def __init__(self): + self._visited = set() + + def visited(self, block): + if id(block) in self._visited: + return True + self._visited.add(id(block)) + return False + def process(self, block): block.accept(self) return block @@ -437,6 +456,7 @@ class CopyingVisitor(AbstractVisitor): """A visitor that copies Blocks.""" def __init__(self): + super().__init__() self._mapping = {} def process(self, block): @@ -445,54 +465,44 @@ class CopyingVisitor(AbstractVisitor): return self._mapping[id(block)] def visitKillProcess(self, block): - if id(block) in self._mapping: - return + assert id(block) not in self._mapping self._mapping[id(block)] = KillProcess() def visitKillThread(self, block): - if id(block) in self._mapping: - return + assert id(block) not in self._mapping self._mapping[id(block)] = KillThread() def visitTrap(self, block): - if id(block) in self._mapping: - return + assert id(block) not in self._mapping self._mapping[id(block)] = Trap() def visitReturnErrno(self, block): - if id(block) in self._mapping: - return + assert id(block) not in self._mapping self._mapping[id(block)] = ReturnErrno(block.errno) def visitTrace(self, block): - if id(block) in self._mapping: - return + assert id(block) not in self._mapping self._mapping[id(block)] = Trace() def visitLog(self, block): - if id(block) in self._mapping: - return + assert id(block) not in self._mapping self._mapping[id(block)] = Log() def visitAllow(self, block): - if id(block) in self._mapping: - return + assert id(block) not in self._mapping self._mapping[id(block)] = Allow() def visitBasicBlock(self, block): - if id(block) in self._mapping: - return + assert id(block) not in self._mapping self._mapping[id(block)] = BasicBlock(block.instructions) def visitValidateArch(self, block): - if id(block) in self._mapping: - return + assert id(block) not in self._mapping self._mapping[id(block)] = ValidateArch( block.arch, self._mapping[id(block.next_block)]) def visitSyscallEntry(self, block): - if id(block) in self._mapping: - return + assert id(block) not in self._mapping self._mapping[id(block)] = SyscallEntry( block.syscall_number, self._mapping[id(block.jt)], @@ -500,15 +510,13 @@ class CopyingVisitor(AbstractVisitor): op=block.op) def visitWideAtom(self, block): - if id(block) in self._mapping: - return + assert id(block) not in self._mapping self._mapping[id(block)] = WideAtom( block.arg_offset, block.op, block.value, self._mapping[id( block.jt)], self._mapping[id(block.jf)]) def visitAtom(self, block): - if id(block) in self._mapping: - return + assert id(block) not in self._mapping self._mapping[id(block)] = Atom(block.arg_index, block.op, block.value, self._mapping[id(block.jt)], self._mapping[id(block.jf)]) @@ -522,8 +530,7 @@ class LoweringVisitor(CopyingVisitor): self._bits = arch.bits def visitAtom(self, block): - if id(block) in self._mapping: - return + assert id(block) not in self._mapping lo = block.value & 0xFFFFFFFF hi = (block.value >> 32) & 0xFFFFFFFF @@ -582,6 +589,7 @@ class FlatteningVisitor: """A visitor that flattens a DAG of Block objects.""" def __init__(self, *, arch, kill_action): + self._visited = set() self._kill_action = kill_action self._instructions = [] self._arch = arch @@ -621,9 +629,14 @@ class FlatteningVisitor: SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance), ] + def visited(self, block): + if id(block) in self._visited: + return True + self._visited.add(id(block)) + return False + def visit(self, block): - if id(block) in self._offsets: - return + assert id(block) not in self._offsets if isinstance(block, BasicBlock): instructions = block.instructions @@ -657,8 +670,15 @@ class ArgFilterForwardingVisitor: """A visitor that forwards visitation to all arg filters.""" def __init__(self, visitor): + self._visited = set() self.visitor = visitor + def visited(self, block): + if id(block) in self._visited: + return True + self._visited.add(id(block)) + return False + def visit(self, block): # All arg filters are BasicBlocks. if not isinstance(block, BasicBlock): diff --git a/tools/compiler_unittest.py b/tools/compiler_unittest.py index 3cc0e6a..ae4c1e5 100755 --- a/tools/compiler_unittest.py +++ b/tools/compiler_unittest.py @@ -460,6 +460,40 @@ class CompileFileTests(unittest.TestCase): self.arch.syscalls[name], number + 1)[1], 'KILL_PROCESS') + def test_compile_huge_filter(self): + """Ensure jumps while compiling a huge policy are still valid.""" + # This is intended to force cases where the AST visitation would result + # in a combinatorial explosion of calls to Block.accept(). An optimized + # implementation should be O(n). + num_entries = 128 + syscalls = {} + # Here we force every single filter to be distinct. Otherwise the + # codegen layer will coalesce filters that compile to the same + # instructions. + policy_contents = [] + for name in random.sample(self.arch.syscalls.keys(), num_entries): + values = random.sample(range(1024), num_entries) + syscalls[name] = values + policy_contents.append( + '%s: %s' % (name, ' || '.join('arg0 == %d' % value + for value in values))) + + path = self._write_file('test.policy', '\n'.join(policy_contents)) + + program = self.compiler.compile_file( + path, + optimization_strategy=compiler.OptimizationStrategy.LINEAR, + kill_action=bpf.KillProcess()) + for name, values in syscalls.items(): + self.assertEqual( + bpf.simulate(program.instructions, + self.arch.arch_nr, self.arch.syscalls[name], + random.choice(values))[1], 'ALLOW') + self.assertEqual( + bpf.simulate(program.instructions, self.arch.arch_nr, + self.arch.syscalls[name], 1025)[1], + 'KILL_PROCESS') + if __name__ == '__main__': unittest.main() |