aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuis Hector Chavez <lhchavez@google.com>2019-03-09 18:46:20 -0800
committerLuis Hector Chavez <lhchavez@google.com>2019-03-27 08:02:50 -0700
commit3b41ae3db9adecad11ed519f891e784c16f21393 (patch)
treecee5ad9c88f68df3ceedd347910a052dcf48a117
parent89a2710839563aaaeefd19906f4bb4431219dac7 (diff)
downloadminijail-3b41ae3db9adecad11ed519f891e784c16f21393.tar.gz
tools/compile_seccomp_policy: Optimize AST visitation
Since the AST nodes form a graph (and NOT a tree), the Block.accept() method was running into a combinatorial explosion of calls to Block.accept(). Some of the Visitor.visit() methods had some logic to prevent processing a Block more than once, but it would still be invoked multiple times. This change adds a Visitor.visited() method so that Block.accept() can stop unnecessary recursion, so that Block.accept() will be evaluated exactly once per Visitor-Block pair, thus strengthening the previous guarantee of Visitor.visit() being evaluated just once (and changing the visitation guard to an assert). This also adds a unit test to prevent this from regression. Bug: chromium:856315 Test: ./tools/compiler_unittest.py Change-Id: I79ba9f66ee63c51ada18175519c2cdab32efcd5b
-rw-r--r--tools/bpf.py76
-rwxr-xr-xtools/compiler_unittest.py34
2 files changed, 82 insertions, 28 deletions
diff --git a/tools/bpf.py b/tools/bpf.py
index bd7007e..75db502 100644
--- a/tools/bpf.py
+++ b/tools/bpf.py
@@ -172,6 +172,8 @@ class BasicBlock(AbstractBlock):
self._instructions = instructions
def accept(self, visitor):
+ if visitor.visited(self):
+ return
visitor.visit(self)
@property
@@ -251,6 +253,8 @@ class ValidateArch(AbstractBlock):
self.next_block = next_block
def accept(self, visitor):
+ if visitor.visited(self):
+ return
self.next_block.accept(visitor)
visitor.visit(self)
@@ -274,6 +278,8 @@ class SyscallEntry(AbstractBlock):
return False
def accept(self, visitor):
+ if visitor.visited(self):
+ return
self.jt.accept(visitor)
self.jf.accept(visitor)
visitor.visit(self)
@@ -299,6 +305,8 @@ class WideAtom(AbstractBlock):
self.jf = jf
def accept(self, visitor):
+ if visitor.visited(self):
+ return
self.jt.accept(visitor)
self.jf.accept(visitor)
visitor.visit(self)
@@ -344,6 +352,8 @@ class Atom(AbstractBlock):
self.value = value
def accept(self, visitor):
+ if visitor.visited(self):
+ return
self.jt.accept(visitor)
self.jf.accept(visitor)
visitor.visit(self)
@@ -352,6 +362,15 @@ class Atom(AbstractBlock):
class AbstractVisitor(abc.ABC):
"""An abstract visitor."""
+ def __init__(self):
+ self._visited = set()
+
+ def visited(self, block):
+ if id(block) in self._visited:
+ return True
+ self._visited.add(id(block))
+ return False
+
def process(self, block):
block.accept(self)
return block
@@ -437,6 +456,7 @@ class CopyingVisitor(AbstractVisitor):
"""A visitor that copies Blocks."""
def __init__(self):
+ super().__init__()
self._mapping = {}
def process(self, block):
@@ -445,54 +465,44 @@ class CopyingVisitor(AbstractVisitor):
return self._mapping[id(block)]
def visitKillProcess(self, block):
- if id(block) in self._mapping:
- return
+ assert id(block) not in self._mapping
self._mapping[id(block)] = KillProcess()
def visitKillThread(self, block):
- if id(block) in self._mapping:
- return
+ assert id(block) not in self._mapping
self._mapping[id(block)] = KillThread()
def visitTrap(self, block):
- if id(block) in self._mapping:
- return
+ assert id(block) not in self._mapping
self._mapping[id(block)] = Trap()
def visitReturnErrno(self, block):
- if id(block) in self._mapping:
- return
+ assert id(block) not in self._mapping
self._mapping[id(block)] = ReturnErrno(block.errno)
def visitTrace(self, block):
- if id(block) in self._mapping:
- return
+ assert id(block) not in self._mapping
self._mapping[id(block)] = Trace()
def visitLog(self, block):
- if id(block) in self._mapping:
- return
+ assert id(block) not in self._mapping
self._mapping[id(block)] = Log()
def visitAllow(self, block):
- if id(block) in self._mapping:
- return
+ assert id(block) not in self._mapping
self._mapping[id(block)] = Allow()
def visitBasicBlock(self, block):
- if id(block) in self._mapping:
- return
+ assert id(block) not in self._mapping
self._mapping[id(block)] = BasicBlock(block.instructions)
def visitValidateArch(self, block):
- if id(block) in self._mapping:
- return
+ assert id(block) not in self._mapping
self._mapping[id(block)] = ValidateArch(
block.arch, self._mapping[id(block.next_block)])
def visitSyscallEntry(self, block):
- if id(block) in self._mapping:
- return
+ assert id(block) not in self._mapping
self._mapping[id(block)] = SyscallEntry(
block.syscall_number,
self._mapping[id(block.jt)],
@@ -500,15 +510,13 @@ class CopyingVisitor(AbstractVisitor):
op=block.op)
def visitWideAtom(self, block):
- if id(block) in self._mapping:
- return
+ assert id(block) not in self._mapping
self._mapping[id(block)] = WideAtom(
block.arg_offset, block.op, block.value, self._mapping[id(
block.jt)], self._mapping[id(block.jf)])
def visitAtom(self, block):
- if id(block) in self._mapping:
- return
+ assert id(block) not in self._mapping
self._mapping[id(block)] = Atom(block.arg_index, block.op, block.value,
self._mapping[id(block.jt)],
self._mapping[id(block.jf)])
@@ -522,8 +530,7 @@ class LoweringVisitor(CopyingVisitor):
self._bits = arch.bits
def visitAtom(self, block):
- if id(block) in self._mapping:
- return
+ assert id(block) not in self._mapping
lo = block.value & 0xFFFFFFFF
hi = (block.value >> 32) & 0xFFFFFFFF
@@ -582,6 +589,7 @@ class FlatteningVisitor:
"""A visitor that flattens a DAG of Block objects."""
def __init__(self, *, arch, kill_action):
+ self._visited = set()
self._kill_action = kill_action
self._instructions = []
self._arch = arch
@@ -621,9 +629,14 @@ class FlatteningVisitor:
SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
]
+ def visited(self, block):
+ if id(block) in self._visited:
+ return True
+ self._visited.add(id(block))
+ return False
+
def visit(self, block):
- if id(block) in self._offsets:
- return
+ assert id(block) not in self._offsets
if isinstance(block, BasicBlock):
instructions = block.instructions
@@ -657,8 +670,15 @@ class ArgFilterForwardingVisitor:
"""A visitor that forwards visitation to all arg filters."""
def __init__(self, visitor):
+ self._visited = set()
self.visitor = visitor
+ def visited(self, block):
+ if id(block) in self._visited:
+ return True
+ self._visited.add(id(block))
+ return False
+
def visit(self, block):
# All arg filters are BasicBlocks.
if not isinstance(block, BasicBlock):
diff --git a/tools/compiler_unittest.py b/tools/compiler_unittest.py
index 3cc0e6a..ae4c1e5 100755
--- a/tools/compiler_unittest.py
+++ b/tools/compiler_unittest.py
@@ -460,6 +460,40 @@ class CompileFileTests(unittest.TestCase):
self.arch.syscalls[name], number + 1)[1],
'KILL_PROCESS')
+ def test_compile_huge_filter(self):
+ """Ensure jumps while compiling a huge policy are still valid."""
+ # This is intended to force cases where the AST visitation would result
+ # in a combinatorial explosion of calls to Block.accept(). An optimized
+ # implementation should be O(n).
+ num_entries = 128
+ syscalls = {}
+ # Here we force every single filter to be distinct. Otherwise the
+ # codegen layer will coalesce filters that compile to the same
+ # instructions.
+ policy_contents = []
+ for name in random.sample(self.arch.syscalls.keys(), num_entries):
+ values = random.sample(range(1024), num_entries)
+ syscalls[name] = values
+ policy_contents.append(
+ '%s: %s' % (name, ' || '.join('arg0 == %d' % value
+ for value in values)))
+
+ path = self._write_file('test.policy', '\n'.join(policy_contents))
+
+ program = self.compiler.compile_file(
+ path,
+ optimization_strategy=compiler.OptimizationStrategy.LINEAR,
+ kill_action=bpf.KillProcess())
+ for name, values in syscalls.items():
+ self.assertEqual(
+ bpf.simulate(program.instructions,
+ self.arch.arch_nr, self.arch.syscalls[name],
+ random.choice(values))[1], 'ALLOW')
+ self.assertEqual(
+ bpf.simulate(program.instructions, self.arch.arch_nr,
+ self.arch.syscalls[name], 1025)[1],
+ 'KILL_PROCESS')
+
if __name__ == '__main__':
unittest.main()