diff options
author | Bill Rassieur <rassb@google.com> | 2019-03-29 04:15:40 +0000 |
---|---|---|
committer | Bill Rassieur <rassb@google.com> | 2019-03-29 04:15:40 +0000 |
commit | 3cd28db90a6421187338e6c44b4b884af8790dd6 (patch) | |
tree | 339b9122dc1427f93693d451c5c494eb67fb85de | |
parent | 788d5b1f188aeb6131746dbe88af497424102b8b (diff) | |
parent | a223091ff4a7cf746dd16312c483241ef602a48c (diff) | |
download | minijail-3cd28db90a6421187338e6c44b4b884af8790dd6.tar.gz |
Merge master@5406228 into git_qt-dev-plus-aosp.
Change-Id: I5dc2e5fda83fd1350b0383b373f22cb1dc0d198c
BUG: 129345239
-rw-r--r-- | minijail0.5 | 2 | ||||
-rw-r--r-- | syscall_filter.c | 41 | ||||
-rw-r--r-- | syscall_filter_unittest.cc | 35 | ||||
-rw-r--r-- | tools/bpf.py | 653 | ||||
-rw-r--r-- | tools/compiler.py | 94 | ||||
-rwxr-xr-x | tools/compiler_unittest.py | 276 | ||||
-rw-r--r-- | tools/parser.py | 483 | ||||
-rwxr-xr-x | tools/parser_unittest.py | 402 | ||||
-rw-r--r-- | tools/testdata/arch_64.json | 5 |
9 files changed, 1987 insertions(+), 4 deletions(-)
diff --git a/minijail0.5 b/minijail0.5 index bf6d0d7..f7b411a 100644 --- a/minijail0.5 +++ b/minijail0.5 @@ -34,6 +34,8 @@ The policy file supplied to the \fB-S\fR argument supports the following syntax: \fB<empty line>\fR \fB# any single line comment\fR +Long lines may be broken up using \\ at the end. + A policy that emulates \fBseccomp\fR(2) in mode 1 may look like: read: 1 write: 1 diff --git a/syscall_filter.c b/syscall_filter.c index c1526a4..049797a 100644 --- a/syscall_filter.c +++ b/syscall_filter.c @@ -499,6 +499,45 @@ int parse_include_statement(struct parser_state *state, char *policy_line, return 0; } +/* + * This is like getline() but supports line wrapping with \. + */ +static ssize_t getmultiline(char **lineptr, size_t *n, FILE *stream) +{ + ssize_t ret = getline(lineptr, n, stream); + if (ret < 0) + return ret; + + char *line = *lineptr; + /* Eat the newline to make processing below easier. */ + if (ret > 0 && line[ret - 1] == '\n') + line[--ret] = '\0'; + + /* If the line doesn't end in a backslash, we're done. */ + if (ret <= 0 || line[ret - 1] != '\\') + return ret; + + /* This line ends in a backslash. Get the nextline. */ + line[--ret] = '\0'; + size_t next_n = 0; + char *next_line = NULL; + ssize_t next_ret = getmultiline(&next_line, &next_n, stream); + if (next_ret == -1) { + free(next_line); + /* We couldn't fully read the line, so return an error. */ + return -1; + } + + /* Merge the lines. 
*/ + *n = ret + next_ret + 2; + line = realloc(line, *n); + line[ret] = ' '; + memcpy(&line[ret + 1], next_line, next_ret + 1); + free(next_line); + *lineptr = line; + return ret; +} + int compile_file(const char *filename, FILE *policy_file, struct filter_block *head, struct filter_block **arg_blocks, struct bpf_labels *labels, int use_ret_trap, int allow_logging, @@ -522,7 +561,7 @@ int compile_file(const char *filename, FILE *policy_file, size_t len = 0; int ret = 0; - while (getline(&line, &len, policy_file) != -1) { + while (getmultiline(&line, &len, policy_file) != -1) { char *policy_line = line; policy_line = strip(policy_line); diff --git a/syscall_filter_unittest.cc b/syscall_filter_unittest.cc index 8a0b19a..aca5f54 100644 --- a/syscall_filter_unittest.cc +++ b/syscall_filter_unittest.cc @@ -1261,6 +1261,41 @@ TEST_F(FileTest, seccomp_read) { EXPECT_EQ(curr_block->next, nullptr); } +TEST_F(FileTest, multiline) { + std::string policy = + "read:\\\n1\n" + "openat:arg0 in\\\n5"; + + const int LABEL_ID = 0; + + FILE* policy_file = write_policy_to_pipe(policy); + ASSERT_NE(policy_file, nullptr); + int res = test_compile_file("policy", policy_file, head_, &arg_blocks_, + &labels_); + fclose(policy_file); + + /* + * Policy should be valid. + */ + ASSERT_EQ(res, 0); + + /* First block is the read. */ + struct filter_block *curr_block = head_; + ASSERT_NE(curr_block, nullptr); + EXPECT_ALLOW_SYSCALL(curr_block->instrs, __NR_read); + + /* Second block is the open. 
*/ + curr_block = curr_block->next; + ASSERT_NE(curr_block, nullptr); + EXPECT_ALLOW_SYSCALL_ARGS(curr_block->instrs, + __NR_openat, + LABEL_ID, + JUMP_JT, + JUMP_JF); + + EXPECT_EQ(curr_block->next, nullptr); +} + TEST(FilterTest, seccomp_mode1) { struct sock_fprog actual; std::string policy = diff --git a/tools/bpf.py b/tools/bpf.py new file mode 100644 index 0000000..e89e93f --- /dev/null +++ b/tools/bpf.py @@ -0,0 +1,653 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (C) 2018 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tools to interact with BPF programs.""" + +import abc +import collections +import struct + +# This comes from syscall(2). Most architectures only support passing 6 args to +# syscalls, but ARM supports passing 7. +MAX_SYSCALL_ARGUMENTS = 7 + +# The following fields were copied from <linux/bpf_common.h>: + +# Instruction classes +BPF_LD = 0x00 +BPF_LDX = 0x01 +BPF_ST = 0x02 +BPF_STX = 0x03 +BPF_ALU = 0x04 +BPF_JMP = 0x05 +BPF_RET = 0x06 +BPF_MISC = 0x07 + +# LD/LDX fields. +# Size +BPF_W = 0x00 +BPF_H = 0x08 +BPF_B = 0x10 +# Mode +BPF_IMM = 0x00 +BPF_ABS = 0x20 +BPF_IND = 0x40 +BPF_MEM = 0x60 +BPF_LEN = 0x80 +BPF_MSH = 0xa0 + +# JMP fields. 
+BPF_JA = 0x00 +BPF_JEQ = 0x10 +BPF_JGT = 0x20 +BPF_JGE = 0x30 +BPF_JSET = 0x40 + +# Source +BPF_K = 0x00 +BPF_X = 0x08 + +BPF_MAXINSNS = 4096 + +# The following fields were copied from <linux/seccomp.h>: + +SECCOMP_RET_KILL_PROCESS = 0x80000000 +SECCOMP_RET_KILL_THREAD = 0x00000000 +SECCOMP_RET_TRAP = 0x00030000 +SECCOMP_RET_ERRNO = 0x00050000 +SECCOMP_RET_TRACE = 0x7ff00000 +SECCOMP_RET_LOG = 0x7ffc0000 +SECCOMP_RET_ALLOW = 0x7fff0000 + +SECCOMP_RET_ACTION_FULL = 0xffff0000 +SECCOMP_RET_DATA = 0x0000ffff + + +def arg_offset(arg_index, hi=False): + """Return the BPF_LD|BPF_W|BPF_ABS addressing-friendly register offset.""" + offsetof_args = 4 + 4 + 8 + arg_width = 8 + return offsetof_args + arg_width * arg_index + (arg_width // 2) * hi + + +def simulate(instructions, arch, syscall_number, *args): + """Simulate a BPF program with the given arguments.""" + args = ((args + (0, ) * + (MAX_SYSCALL_ARGUMENTS - len(args)))[:MAX_SYSCALL_ARGUMENTS]) + input_memory = struct.pack('IIQ' + 'Q' * MAX_SYSCALL_ARGUMENTS, + syscall_number, arch, 0, *args) + + register = 0 + program_counter = 0 + cost = 0 + while program_counter < len(instructions): + ins = instructions[program_counter] + program_counter += 1 + cost += 1 + if ins.code == BPF_LD | BPF_W | BPF_ABS: + register = struct.unpack('I', input_memory[ins.k:ins.k + 4])[0] + elif ins.code == BPF_JMP | BPF_JA | BPF_K: + program_counter += ins.k + elif ins.code == BPF_JMP | BPF_JEQ | BPF_K: + if register == ins.k: + program_counter += ins.jt + else: + program_counter += ins.jf + elif ins.code == BPF_JMP | BPF_JGT | BPF_K: + if register > ins.k: + program_counter += ins.jt + else: + program_counter += ins.jf + elif ins.code == BPF_JMP | BPF_JGE | BPF_K: + if register >= ins.k: + program_counter += ins.jt + else: + program_counter += ins.jf + elif ins.code == BPF_JMP | BPF_JSET | BPF_K: + if register & ins.k != 0: + program_counter += ins.jt + else: + program_counter += ins.jf + elif ins.code == BPF_RET: + if ins.k == 
SECCOMP_RET_KILL_PROCESS: + return (cost, 'KILL_PROCESS') + if ins.k == SECCOMP_RET_KILL_THREAD: + return (cost, 'KILL_THREAD') + if ins.k == SECCOMP_RET_TRAP: + return (cost, 'TRAP') + if (ins.k & SECCOMP_RET_ACTION_FULL) == SECCOMP_RET_ERRNO: + return (cost, 'ERRNO', ins.k & SECCOMP_RET_DATA) + if ins.k == SECCOMP_RET_TRACE: + return (cost, 'TRACE') + if ins.k == SECCOMP_RET_LOG: + return (cost, 'LOG') + if ins.k == SECCOMP_RET_ALLOW: + return (cost, 'ALLOW') + raise Exception('unknown return %#x' % ins.k) + else: + raise Exception('unknown instruction %r' % (ins, )) + raise Exception('out-of-bounds') + + +class SockFilter( + collections.namedtuple('SockFilter', ['code', 'jt', 'jf', 'k'])): + """A representation of struct sock_filter.""" + + __slots__ = () + + def encode(self): + """Return an encoded version of the SockFilter.""" + return struct.pack('HBBI', self.code, self.jt, self.jf, self.k) + + +class AbstractBlock(abc.ABC): + """A class that implements the visitor pattern.""" + + def __init__(self): + super().__init__() + + @abc.abstractmethod + def accept(self, visitor): + pass + + +class BasicBlock(AbstractBlock): + """A concrete implementation of AbstractBlock that has been compiled.""" + + def __init__(self, instructions): + super().__init__() + self._instructions = instructions + + def accept(self, visitor): + visitor.visit(self) + + @property + def instructions(self): + return self._instructions + + @property + def opcodes(self): + return b''.join(i.encode() for i in self._instructions) + + def __eq__(self, o): + if not isinstance(o, BasicBlock): + return False + return self._instructions == o._instructions + + +class KillProcess(BasicBlock): + """A BasicBlock that unconditionally returns KILL_PROCESS.""" + + def __init__(self): + super().__init__( + [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_KILL_PROCESS)]) + + +class KillThread(BasicBlock): + """A BasicBlock that unconditionally returns KILL_THREAD.""" + + def __init__(self): + super().__init__( + 
[SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_KILL_THREAD)]) + + +class Trap(BasicBlock): + """A BasicBlock that unconditionally returns TRAP.""" + + def __init__(self): + super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_TRAP)]) + + +class Trace(BasicBlock): + """A BasicBlock that unconditionally returns TRACE.""" + + def __init__(self): + super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_TRACE)]) + + +class Log(BasicBlock): + """A BasicBlock that unconditionally returns LOG.""" + + def __init__(self): + super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_LOG)]) + + +class ReturnErrno(BasicBlock): + """A BasicBlock that unconditionally returns the specified errno.""" + + def __init__(self, errno): + super().__init__([ + SockFilter(BPF_RET, 0x00, 0x00, + SECCOMP_RET_ERRNO | (errno & SECCOMP_RET_DATA)) + ]) + self.errno = errno + + +class Allow(BasicBlock): + """A BasicBlock that unconditionally returns ALLOW.""" + + def __init__(self): + super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_ALLOW)]) + + +class ValidateArch(AbstractBlock): + """An AbstractBlock that validates the architecture.""" + + def __init__(self, next_block): + super().__init__() + self.next_block = next_block + + def accept(self, visitor): + self.next_block.accept(visitor) + visitor.visit(self) + + +class SyscallEntry(AbstractBlock): + """An abstract block that represents a syscall comparison in a DAG.""" + + def __init__(self, syscall_number, jt, jf, *, op=BPF_JEQ): + super().__init__() + self.op = op + self.syscall_number = syscall_number + self.jt = jt + self.jf = jf + + def __lt__(self, o): + # Defined because we want to compare tuples that contain SyscallEntries. + return False + + def __gt__(self, o): + # Defined because we want to compare tuples that contain SyscallEntries. 
+ return False + + def accept(self, visitor): + self.jt.accept(visitor) + self.jf.accept(visitor) + visitor.visit(self) + + def __lt__(self, o): + # Defined because we want to compare tuples that contain SyscallEntries. + return False + + def __gt__(self, o): + # Defined because we want to compare tuples that contain SyscallEntries. + return False + + +class WideAtom(AbstractBlock): + """A BasicBlock that represents a 32-bit wide atom.""" + + def __init__(self, arg_offset, op, value, jt, jf): + super().__init__() + self.arg_offset = arg_offset + self.op = op + self.value = value + self.jt = jt + self.jf = jf + + def accept(self, visitor): + self.jt.accept(visitor) + self.jf.accept(visitor) + visitor.visit(self) + + +class Atom(AbstractBlock): + """A BasicBlock that represents an atom (a simple comparison operation).""" + + def __init__(self, arg_index, op, value, jt, jf): + super().__init__() + if op == '==': + op = BPF_JEQ + elif op == '!=': + op = BPF_JEQ + jt, jf = jf, jt + elif op == '>': + op = BPF_JGT + elif op == '<=': + op = BPF_JGT + jt, jf = jf, jt + elif op == '>=': + op = BPF_JGE + elif op == '<': + op = BPF_JGE + jt, jf = jf, jt + elif op == '&': + op = BPF_JSET + elif op == 'in': + op = BPF_JSET + # The mask is negated, so the comparison will be true when the + # argument includes a flag that wasn't listed in the original + # (non-negated) mask. This would be the failure case, so we switch + # |jt| and |jf|. 
+ value = (~value) & ((1 << 64) - 1) + jt, jf = jf, jt + else: + raise Exception('Unknown operator %s' % op) + + self.arg_index = arg_index + self.op = op + self.jt = jt + self.jf = jf + self.value = value + + def accept(self, visitor): + self.jt.accept(visitor) + self.jf.accept(visitor) + visitor.visit(self) + + +class AbstractVisitor(abc.ABC): + """An abstract visitor.""" + + def process(self, block): + block.accept(self) + return block + + def visit(self, block): + if isinstance(block, KillProcess): + self.visitKillProcess(block) + elif isinstance(block, KillThread): + self.visitKillThread(block) + elif isinstance(block, Trap): + self.visitTrap(block) + elif isinstance(block, ReturnErrno): + self.visitReturnErrno(block) + elif isinstance(block, Trace): + self.visitTrace(block) + elif isinstance(block, Log): + self.visitLog(block) + elif isinstance(block, Allow): + self.visitAllow(block) + elif isinstance(block, BasicBlock): + self.visitBasicBlock(block) + elif isinstance(block, ValidateArch): + self.visitValidateArch(block) + elif isinstance(block, SyscallEntry): + self.visitSyscallEntry(block) + elif isinstance(block, WideAtom): + self.visitWideAtom(block) + elif isinstance(block, Atom): + self.visitAtom(block) + else: + raise Exception('Unknown block type: %r' % block) + + @abc.abstractmethod + def visitKillProcess(self, block): + pass + + @abc.abstractmethod + def visitKillThread(self, block): + pass + + @abc.abstractmethod + def visitTrap(self, block): + pass + + @abc.abstractmethod + def visitReturnErrno(self, block): + pass + + @abc.abstractmethod + def visitTrace(self, block): + pass + + @abc.abstractmethod + def visitLog(self, block): + pass + + @abc.abstractmethod + def visitAllow(self, block): + pass + + @abc.abstractmethod + def visitBasicBlock(self, block): + pass + + @abc.abstractmethod + def visitValidateArch(self, block): + pass + + @abc.abstractmethod + def visitSyscallEntry(self, block): + pass + + @abc.abstractmethod + def visitWideAtom(self, 
block): + pass + + @abc.abstractmethod + def visitAtom(self, block): + pass + + +class CopyingVisitor(AbstractVisitor): + """A visitor that copies Blocks.""" + + def __init__(self): + self._mapping = {} + + def process(self, block): + self._mapping = {} + block.accept(self) + return self._mapping[id(block)] + + def visitKillProcess(self, block): + if id(block) in self._mapping: + return + self._mapping[id(block)] = KillProcess() + + def visitKillThread(self, block): + if id(block) in self._mapping: + return + self._mapping[id(block)] = KillThread() + + def visitTrap(self, block): + if id(block) in self._mapping: + return + self._mapping[id(block)] = Trap() + + def visitReturnErrno(self, block): + if id(block) in self._mapping: + return + self._mapping[id(block)] = ReturnErrno(block.errno) + + def visitTrace(self, block): + if id(block) in self._mapping: + return + self._mapping[id(block)] = Trace() + + def visitLog(self, block): + if id(block) in self._mapping: + return + self._mapping[id(block)] = Log() + + def visitAllow(self, block): + if id(block) in self._mapping: + return + self._mapping[id(block)] = Allow() + + def visitBasicBlock(self, block): + if id(block) in self._mapping: + return + self._mapping[id(block)] = BasicBlock(block.instructions) + + def visitValidateArch(self, block): + if id(block) in self._mapping: + return + self._mapping[id(block)] = ValidateArch( + block.arch, self._mapping[id(block.next_block)]) + + def visitSyscallEntry(self, block): + if id(block) in self._mapping: + return + self._mapping[id(block)] = SyscallEntry( + block.syscall_number, + self._mapping[id(block.jt)], + self._mapping[id(block.jf)], + op=block.op) + + def visitWideAtom(self, block): + if id(block) in self._mapping: + return + self._mapping[id(block)] = WideAtom( + block.arg_offset, block.op, block.value, self._mapping[id( + block.jt)], self._mapping[id(block.jf)]) + + def visitAtom(self, block): + if id(block) in self._mapping: + return + self._mapping[id(block)] = 
Atom(block.arg_index, block.op, block.value, + self._mapping[id(block.jt)], + self._mapping[id(block.jf)]) + + +class LoweringVisitor(CopyingVisitor): + """A visitor that lowers Atoms into WideAtoms.""" + + def __init__(self, *, arch): + super().__init__() + self._bits = arch.bits + + def visitAtom(self, block): + if id(block) in self._mapping: + return + + lo = block.value & 0xFFFFFFFF + hi = (block.value >> 32) & 0xFFFFFFFF + + lo_block = WideAtom( + arg_offset(block.arg_index, False), block.op, lo, + self._mapping[id(block.jt)], self._mapping[id(block.jf)]) + + if self._bits == 32: + self._mapping[id(block)] = lo_block + return + + if block.op in (BPF_JGE, BPF_JGT): + # hi_1,lo_1 <op> hi_2,lo_2 + # + # hi_1 > hi_2 || hi_1 == hi_2 && lo_1 <op> lo_2 + if hi == 0: + # Special case: it's not needed to check whether |hi_1 == hi_2|, + # because it's true iff the JGT test fails. + self._mapping[id(block)] = WideAtom( + arg_offset(block.arg_index, True), BPF_JGT, hi, + self._mapping[id(block.jt)], lo_block) + return + hi_eq_block = WideAtom( + arg_offset(block.arg_index, True), BPF_JEQ, hi, lo_block, + self._mapping[id(block.jf)]) + self._mapping[id(block)] = WideAtom( + arg_offset(block.arg_index, True), BPF_JGT, hi, + self._mapping[id(block.jt)], hi_eq_block) + return + if block.op == BPF_JSET: + # hi_1,lo_1 & hi_2,lo_2 + # + # hi_1 & hi_2 || lo_1 & lo_2 + if hi == 0: + # Special case: |hi_1 & hi_2| will never be True, so jump + # directly into the |lo_1 & lo_2| case. 
+ self._mapping[id(block)] = lo_block + return + self._mapping[id(block)] = WideAtom( + arg_offset(block.arg_index, True), block.op, hi, + self._mapping[id(block.jt)], lo_block) + return + + assert block.op == BPF_JEQ, block.op + + # hi_1,lo_1 == hi_2,lo_2 + # + # hi_1 == hi_2 && lo_1 == lo_2 + self._mapping[id(block)] = WideAtom( + arg_offset(block.arg_index, True), block.op, hi, lo_block, + self._mapping[id(block.jf)]) + + +class FlatteningVisitor: + """A visitor that flattens a DAG of Block objects.""" + + def __init__(self, *, arch, kill_action): + self._kill_action = kill_action + self._instructions = [] + self._arch = arch + self._offsets = {} + + @property + def result(self): + return BasicBlock(self._instructions) + + def _distance(self, block): + distance = self._offsets[id(block)] + len(self._instructions) + assert distance >= 0 + return distance + + def _emit_load_arg(self, offset): + return [SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, offset)] + + def _emit_jmp(self, op, value, jt_distance, jf_distance): + if jt_distance < 0x100 and jf_distance < 0x100: + return [ + SockFilter(BPF_JMP | op | BPF_K, jt_distance, jf_distance, + value), + ] + if jt_distance + 1 < 0x100: + return [ + SockFilter(BPF_JMP | op | BPF_K, jt_distance + 1, 0, value), + SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance), + ] + if jf_distance + 1 < 0x100: + return [ + SockFilter(BPF_JMP | op | BPF_K, 0, jf_distance + 1, value), + SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance), + ] + return [ + SockFilter(BPF_JMP | op | BPF_K, 0, 1, value), + SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance + 1), + SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance), + ] + + def visit(self, block): + if id(block) in self._offsets: + return + + if isinstance(block, BasicBlock): + instructions = block.instructions + elif isinstance(block, ValidateArch): + instructions = [ + SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 4), + SockFilter(BPF_JMP | BPF_JEQ | BPF_K, + self._distance(block.next_block) + 1, 0, + 
self._arch.arch_nr), + ] + self._kill_action.instructions + [ + SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 0), + ] + elif isinstance(block, SyscallEntry): + instructions = self._emit_jmp(block.op, block.syscall_number, + self._distance(block.jt), + self._distance(block.jf)) + elif isinstance(block, WideAtom): + instructions = ( + self._emit_load_arg(block.arg_offset) + self._emit_jmp( + block.op, block.value, self._distance(block.jt), + self._distance(block.jf))) + else: + raise Exception('Unknown block type: %r' % block) + + self._instructions = instructions + self._instructions + self._offsets[id(block)] = -len(self._instructions) + return diff --git a/tools/compiler.py b/tools/compiler.py new file mode 100644 index 0000000..96800f1 --- /dev/null +++ b/tools/compiler.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (C) 2018 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""A BPF compiler for the Minijail policy file.""" + +from __future__ import print_function + +import bpf +import parser # pylint: disable=wrong-import-order + + +class SyscallPolicyEntry: + """The parsed version of a seccomp policy line.""" + + def __init__(self, name, number, frequency): + self.name = name + self.number = number + self.frequency = frequency + self.accumulated = 0 + self.filter = None + + def __repr__(self): + return ('SyscallPolicyEntry<name: %s, number: %d, ' + 'frequency: %d, filter: %r>') % (self.name, self.number, + self.frequency, + self.filter.instructions + if self.filter else None) + + def simulate(self, arch, syscall_number, *args): + """Simulate the policy with the given arguments.""" + if not self.filter: + return (0, 'ALLOW') + return bpf.simulate(self.filter.instructions, arch, syscall_number, + *args) + + +class PolicyCompiler: + """A parser for the Minijail seccomp policy file format.""" + + def __init__(self, arch): + self._arch = arch + + def compile_filter_statement(self, filter_statement, *, kill_action): + """Compile one parser.FilterStatement into BPF.""" + policy_entry = SyscallPolicyEntry(filter_statement.syscall.name, + filter_statement.syscall.number, + filter_statement.frequency) + # In each step of the way, the false action is the one that is taken if + # the immediate boolean condition does not match. This means that the + # false action taken here is the one that applies if the whole + # expression fails to match. + false_action = filter_statement.filters[-1].action + if false_action == bpf.Allow(): + return policy_entry + # We then traverse the list of filters backwards since we want + # the root of the DAG to be the very first boolean operation in + # the filter chain. + for filt in filter_statement.filters[:-1][::-1]: + for disjunction in filt.expression: + # This is the jump target of the very last comparison in the + # conjunction. 
Given that any conjunction that succeeds should + # make the whole expression succeed, make the very last + # comparison jump to the accept action if it succeeds. + true_action = filt.action + for atom in disjunction: + block = bpf.Atom(atom.argument_index, atom.op, atom.value, + true_action, false_action) + true_action = block + false_action = true_action + policy_filter = false_action + + # Lower all Atoms into WideAtoms. + lowering_visitor = bpf.LoweringVisitor(arch=self._arch) + policy_filter = lowering_visitor.process(policy_filter) + + # Flatten the IR DAG into a single BasicBlock. + flattening_visitor = bpf.FlatteningVisitor( + arch=self._arch, kill_action=kill_action) + policy_filter.accept(flattening_visitor) + policy_entry.filter = flattening_visitor.result + return policy_entry diff --git a/tools/compiler_unittest.py b/tools/compiler_unittest.py new file mode 100755 index 0000000..ba66e62 --- /dev/null +++ b/tools/compiler_unittest.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (C) 2018 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Unittests for the compiler module.""" + +from __future__ import print_function + +import os +import tempfile +import unittest + +import arch +import bpf +import compiler +import parser # pylint: disable=wrong-import-order + +ARCH_64 = arch.Arch.load_from_json( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), 'testdata/arch_64.json')) + + +class CompileFilterStatementTests(unittest.TestCase): + """Tests for PolicyCompiler.compile_filter_statement.""" + + def setUp(self): + self.arch = ARCH_64 + self.compiler = compiler.PolicyCompiler(self.arch) + + def _compile(self, line): + with tempfile.NamedTemporaryFile(mode='w') as policy_file: + policy_file.write(line) + policy_file.flush() + policy_parser = parser.PolicyParser( + self.arch, kill_action=bpf.KillProcess()) + parsed_policy = policy_parser.parse_file(policy_file.name) + assert len(parsed_policy.filter_statements) == 1 + return self.compiler.compile_filter_statement( + parsed_policy.filter_statements[0], + kill_action=bpf.KillProcess()) + + def test_allow(self): + """Accept lines where the syscall is accepted unconditionally.""" + block = self._compile('read: allow') + self.assertEqual(block.filter, None) + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 0)[1], 'ALLOW') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 1)[1], 'ALLOW') + + def test_arg0_eq_generated_code(self): + """Accept lines with an argument filter with ==.""" + block = self._compile('read: arg0 == 0x100') + # It might be a bit brittle to check the generated code in each test + # case instead of just the behavior, but there should be at least one + # test where this happens. + self.assertEqual( + block.filter.instructions, + [ + bpf.SockFilter(bpf.BPF_LD | bpf.BPF_W | bpf.BPF_ABS, 0, 0, + bpf.arg_offset(0, True)), + # Jump to KILL_PROCESS if the high word does not match. 
+ bpf.SockFilter(bpf.BPF_JMP | bpf.BPF_JEQ | bpf.BPF_K, 0, 2, 0), + bpf.SockFilter(bpf.BPF_LD | bpf.BPF_W | bpf.BPF_ABS, 0, 0, + bpf.arg_offset(0, False)), + # Jump to KILL_PROCESS if the low word does not match. + bpf.SockFilter(bpf.BPF_JMP | bpf.BPF_JEQ | bpf.BPF_K, 1, 0, + 0x100), + bpf.SockFilter(bpf.BPF_RET, 0, 0, + bpf.SECCOMP_RET_KILL_PROCESS), + bpf.SockFilter(bpf.BPF_RET, 0, 0, bpf.SECCOMP_RET_ALLOW), + ]) + + def test_arg0_comparison_operators(self): + """Accept lines with an argument filter with comparison operators.""" + biases = (-1, 0, 1) + # For each operator, store the expectations of simulating the program + # against the constant plus each entry from the |biases| array. + cases = ( + ('==', ('KILL_PROCESS', 'ALLOW', 'KILL_PROCESS')), + ('!=', ('ALLOW', 'KILL_PROCESS', 'ALLOW')), + ('<', ('ALLOW', 'KILL_PROCESS', 'KILL_PROCESS')), + ('<=', ('ALLOW', 'ALLOW', 'KILL_PROCESS')), + ('>', ('KILL_PROCESS', 'KILL_PROCESS', 'ALLOW')), + ('>=', ('KILL_PROCESS', 'ALLOW', 'ALLOW')), + ) + for operator, expectations in cases: + block = self._compile('read: arg0 %s 0x100' % operator) + + # Check the filter's behavior. 
+ for bias, expectation in zip(biases, expectations): + self.assertEqual( + block.simulate(self.arch.arch_nr, + self.arch.syscalls['read'], + 0x100 + bias)[1], expectation) + + def test_arg0_mask_operator(self): + """Accept lines with an argument filter with &.""" + block = self._compile('read: arg0 & 0x3') + + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 0)[1], 'KILL_PROCESS') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 1)[1], 'ALLOW') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 2)[1], 'ALLOW') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 3)[1], 'ALLOW') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 4)[1], 'KILL_PROCESS') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 5)[1], 'ALLOW') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 6)[1], 'ALLOW') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 7)[1], 'ALLOW') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 8)[1], 'KILL_PROCESS') + + def test_arg0_in_operator(self): + """Accept lines with an argument filter with in.""" + block = self._compile('read: arg0 in 0x3') + + # The 'in' operator only ensures that no bits outside the mask are set, + # which means that 0 is always allowed. 
+ self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 0)[1], 'ALLOW') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 1)[1], 'ALLOW') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 2)[1], 'ALLOW') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 3)[1], 'ALLOW') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 4)[1], 'KILL_PROCESS') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 5)[1], 'KILL_PROCESS') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 6)[1], 'KILL_PROCESS') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 7)[1], 'KILL_PROCESS') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 8)[1], 'KILL_PROCESS') + + def test_arg0_short_gt_ge_comparisons(self): + """Ensure that the short comparison optimization kicks in.""" + if self.arch.bits == 32: + return + short_constant_str = '0xdeadbeef' + short_constant = int(short_constant_str, base=0) + long_constant_str = '0xbadc0ffee0ddf00d' + long_constant = int(long_constant_str, base=0) + biases = (-1, 0, 1) + # For each operator, store the expectations of simulating the program + # against the constant plus each entry from the |biases| array. + cases = ( + ('<', ('ALLOW', 'KILL_PROCESS', 'KILL_PROCESS')), + ('<=', ('ALLOW', 'ALLOW', 'KILL_PROCESS')), + ('>', ('KILL_PROCESS', 'KILL_PROCESS', 'ALLOW')), + ('>=', ('KILL_PROCESS', 'ALLOW', 'ALLOW')), + ) + for operator, expectations in cases: + short_block = self._compile( + 'read: arg0 %s %s' % (operator, short_constant_str)) + long_block = self._compile( + 'read: arg0 %s %s' % (operator, long_constant_str)) + + # Check that the emitted code is shorter when the high word of the + # constant is zero. 
+ self.assertLess( + len(short_block.filter.instructions), + len(long_block.filter.instructions)) + + # Check the filter's behavior. + for bias, expectation in zip(biases, expectations): + self.assertEqual( + long_block.simulate(self.arch.arch_nr, + self.arch.syscalls['read'], + long_constant + bias)[1], expectation) + self.assertEqual( + short_block.simulate( + self.arch.arch_nr, self.arch.syscalls['read'], + short_constant + bias)[1], expectation) + + def test_and_or(self): + """Accept lines with a complex expression in DNF.""" + block = self._compile('read: arg0 == 0 && arg1 == 0 || arg0 == 1') + + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], 0, + 0)[1], 'ALLOW') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], 0, + 1)[1], 'KILL_PROCESS') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], 1, + 0)[1], 'ALLOW') + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], 1, + 1)[1], 'ALLOW') + + def test_ret_errno(self): + """Accept lines that return errno.""" + block = self._compile('read : arg0 == 0 || arg0 == 1 ; return 1') + + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 0)[1:], ('ERRNO', 1)) + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 1)[1:], ('ERRNO', 1)) + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 2)[1], 'KILL_PROCESS') + + def test_ret_errno_unconditionally(self): + """Accept lines that return errno unconditionally.""" + block = self._compile('read: return 1') + + self.assertEqual( + block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], + 0)[1:], ('ERRNO', 1)) + + def test_mmap_write_xor_exec(self): + """Accept the idiomatic filter for mmap.""" + block = self._compile( + 'read : arg0 in ~PROT_WRITE || arg0 in ~PROT_EXEC') + + prot_exec_and_write = 6 + for prot in range(0, 0xf): + if (prot & 
prot_exec_and_write) == prot_exec_and_write: + self.assertEqual( + block.simulate(self.arch.arch_nr, + self.arch.syscalls['read'], prot)[1], + 'KILL_PROCESS') + else: + self.assertEqual( + block.simulate(self.arch.arch_nr, + self.arch.syscalls['read'], prot)[1], + 'ALLOW') + + +if __name__ == '__main__': + unittest.main() diff --git a/tools/parser.py b/tools/parser.py index 05b6628..41d2f52 100644 --- a/tools/parser.py +++ b/tools/parser.py @@ -21,8 +21,12 @@ from __future__ import division from __future__ import print_function import collections +import itertools +import os.path import re +import bpf + Token = collections.namedtuple('token', ['type', 'value', 'filename', 'line', 'column']) @@ -30,7 +34,9 @@ Token = collections.namedtuple('token', _TOKEN_SPECIFICATION = ( ('COMMENT', r'#.*$'), ('WHITESPACE', r'\s+'), + ('DEFAULT', r'@default'), ('INCLUDE', r'@include'), + ('FREQUENCY', r'@frequency'), ('PATH', r'(?:\.)?/\S+'), ('NUMERIC_CONSTANT', r'-?0[xX][0-9a-fA-F]+|-?0[Oo][0-7]+|-?[0-9]+'), ('COLON', r':'), @@ -137,12 +143,51 @@ class ParserState: return tokens +Atom = collections.namedtuple('Atom', ['argument_index', 'op', 'value']) +"""A single boolean comparison within a filter expression.""" + +Filter = collections.namedtuple('Filter', ['expression', 'action']) +"""The result of parsing a DNF filter expression, with its action. + +Since the expression is in Disjunctive Normal Form, it is composed of two levels +of lists, one for disjunctions and the inner one for conjunctions. The elements +of the inner list are Atoms. +""" + +Syscall = collections.namedtuple('Syscall', ['name', 'number']) +"""A system call.""" + +ParsedFilterStatement = collections.namedtuple('ParsedFilterStatement', + ['syscalls', 'filters']) +"""The result of parsing a filter statement. + +Statements have a list of syscalls, and an associated list of filters that will +be evaluated sequentially when any of the syscalls is invoked. 
+""" + +FilterStatement = collections.namedtuple('FilterStatement', + ['syscall', 'frequency', 'filters']) +"""The filter list for a particular syscall. + +This is a mapping from one syscall to a list of filters that are evaluated +sequentially. The last filter is always an unconditional action. +""" + +ParsedPolicy = collections.namedtuple('ParsedPolicy', + ['default_action', 'filter_statements']) +"""The result of parsing a minijail .policy file.""" + + # pylint: disable=too-few-public-methods class PolicyParser: """A parser for the Minijail seccomp policy file format.""" - def __init__(self, arch): + def __init__(self, arch, *, kill_action, include_depth_limit=10): self._parser_states = [ParserState("<memory>")] + self._kill_action = kill_action + self._include_depth_limit = include_depth_limit + self._default_action = self._kill_action + self._frequency_mapping = collections.defaultdict(int) self._arch = arch @property @@ -228,3 +273,439 @@ class PolicyParser: else: self._parser_state.error('empty constant') return value + + # atom = argument , op , value + # ; + def _parse_atom(self, tokens): + if not tokens: + self._parser_state.error('missing argument') + argument = tokens.pop(0) + if argument.type != 'ARGUMENT': + self._parser_state.error('invalid argument', token=argument) + + if not tokens: + self._parser_state.error('missing operator') + operator = tokens.pop(0) + if operator.type != 'OP': + self._parser_state.error('invalid operator', token=operator) + + value = self.parse_value(tokens) + argument_index = int(argument.value[3:]) + if not (0 <= argument_index < bpf.MAX_SYSCALL_ARGUMENTS): + self._parser_state.error('invalid argument', token=argument) + return Atom(argument_index, operator.value, value) + + # clause = atom , [ { '&&' , atom } ] + # ; + def _parse_clause(self, tokens): + atoms = [] + while tokens: + atoms.append(self._parse_atom(tokens)) + if not tokens or tokens[0].type != 'AND': + break + tokens.pop(0) + else: + 
self._parser_state.error('empty clause') + return atoms + + # argument-expression = clause , [ { '||' , clause } ] + # ; + def parse_argument_expression(self, tokens): + """Parse a argument expression in Disjunctive Normal Form. + + Since BPF disallows back jumps, we build the basic blocks in reverse + order so that all the jump targets are known by the time we need to + reference them. + """ + + clauses = [] + while tokens: + clauses.append(self._parse_clause(tokens)) + if not tokens or tokens[0].type != 'OR': + break + tokens.pop(0) + else: + self._parser_state.error('empty argument expression') + return clauses + + # default-action = 'kill-process' + # | 'kill-thread' + # | 'kill' + # | 'trap' + # ; + def _parse_default_action(self, tokens): + if not tokens: + self._parser_state.error('missing default action') + action_token = tokens.pop(0) + if action_token.type != 'ACTION': + return self._parser_state.error( + 'invalid default action', token=action_token) + if action_token.value == 'kill-process': + return bpf.KillProcess() + if action_token.value == 'kill-thread': + return bpf.KillThread() + if action_token.value == 'kill': + return self._kill_action + if action_token.value == 'trap': + return bpf.Trap() + return self._parser_state.error( + 'invalid permissive default action', token=action_token) + + # action = 'allow' | '1' + # | 'kill-process' + # | 'kill-thread' + # | 'kill' + # | 'trap' + # | 'trace' + # | 'log' + # | 'return' , single-constant + # ; + def _parse_action(self, tokens): + if not tokens: + self._parser_state.error('missing action') + action_token = tokens.pop(0) + if action_token.type == 'ACTION': + if action_token.value == 'allow': + return bpf.Allow() + if action_token.value == 'kill': + return self._kill_action + if action_token.value == 'kill-process': + return bpf.KillProcess() + if action_token.value == 'kill-thread': + return bpf.KillThread() + if action_token.value == 'trap': + return bpf.Trap() + if action_token.value == 'trace': + 
return bpf.Trace() + if action_token.value == 'log': + return bpf.Log() + elif action_token.type == 'NUMERIC_CONSTANT': + constant = self._parse_single_constant(action_token) + if constant == 1: + return bpf.Allow() + elif action_token.type == 'RETURN': + if not tokens: + self._parser_state.error('missing return value') + return bpf.ReturnErrno(self._parse_single_constant(tokens.pop(0))) + return self._parser_state.error('invalid action', token=action_token) + + # single-filter = action + # | argument-expression , [ ';' , action ] + # ; + def _parse_single_filter(self, tokens): + if not tokens: + self._parser_state.error('missing filter') + if tokens[0].type == 'ARGUMENT': + # Only argument expressions can start with an ARGUMENT token. + argument_expression = self.parse_argument_expression(tokens) + if tokens and tokens[0].type == 'SEMICOLON': + tokens.pop(0) + action = self._parse_action(tokens) + else: + action = bpf.Allow() + return Filter(argument_expression, action) + else: + return Filter(None, self._parse_action(tokens)) + + # filter = '{' , single-filter , [ { ',' , single-filter } ] , '}' + # | single-filter + # ; + def parse_filter(self, tokens): + """Parse a filter and return a list of Filter objects.""" + if not tokens: + self._parser_state.error('missing filter') + filters = [] + if tokens[0].type == 'LBRACE': + opening_brace = tokens.pop(0) + while tokens: + filters.append(self._parse_single_filter(tokens)) + if not tokens or tokens[0].type != 'COMMA': + break + tokens.pop(0) + if not tokens or tokens[0].type != 'RBRACE': + self._parser_state.error('unclosed brace', token=opening_brace) + tokens.pop(0) + else: + filters.append(self._parse_single_filter(tokens)) + return filters + + # key-value-pair = identifier , '=', identifier , [ { ',' , identifier } ] + # ; + def _parse_key_value_pair(self, tokens): + if not tokens: + self._parser_state.error('missing key') + key = tokens.pop(0) + if key.type != 'IDENTIFIER': + self._parser_state.error('invalid 
key', token=key) + if not tokens: + self._parser_state.error('missing equal') + if tokens[0].type != 'EQUAL': + self._parser_state.error('invalid equal', token=tokens[0]) + tokens.pop(0) + value_list = [] + while tokens: + value = tokens.pop(0) + if value.type != 'IDENTIFIER': + self._parser_state.error('invalid value', token=value) + value_list.append(value.value) + if not tokens or tokens[0].type != 'COMMA': + break + tokens.pop(0) + else: + self._parser_state.error('empty value') + return (key.value, value_list) + + # metadata = '[' , key-value-pair , [ { ';' , key-value-pair } ] , ']' + # ; + def _parse_metadata(self, tokens): + if not tokens: + self._parser_state.error('missing opening bracket') + opening_bracket = tokens.pop(0) + if opening_bracket.type != 'LBRACKET': + self._parser_state.error( + 'invalid opening bracket', token=opening_bracket) + metadata = {} + while tokens: + first_token = tokens[0] + key, value = self._parse_key_value_pair(tokens) + if key in metadata: + self._parser_state.error( + 'duplicate metadata key: "%s"' % key, token=first_token) + metadata[key] = value + if not tokens or tokens[0].type != 'SEMICOLON': + break + tokens.pop(0) + if not tokens or tokens[0].type != 'RBRACKET': + self._parser_state.error('unclosed bracket', token=opening_bracket) + tokens.pop(0) + return metadata + + # syscall-descriptor = syscall-name , [ metadata ] + # | libc-function , [ metadata ] + # ; + def _parse_syscall_descriptor(self, tokens): + if not tokens: + self._parser_state.error('missing syscall descriptor') + syscall_descriptor = tokens.pop(0) + if syscall_descriptor.type != 'IDENTIFIER': + self._parser_state.error( + 'invalid syscall descriptor', token=syscall_descriptor) + # TODO(lhchavez): Support libc function names. 
+ if tokens and tokens[0].type == 'LBRACKET': + metadata = self._parse_metadata(tokens) + if 'arch' in metadata and self._arch.arch_name not in metadata['arch']: + return () + if syscall_descriptor.value not in self._arch.syscalls: + self._parser_state.error( + 'nonexistent syscall', token=syscall_descriptor) + return (Syscall(syscall_descriptor.value, + self._arch.syscalls[syscall_descriptor.value]), ) + + # filter-statement = '{' , syscall-descriptor , [ { ',', syscall-descriptor } ] , '}' , + # ':' , filter + # | syscall-descriptor , ':' , filter + # ; + def parse_filter_statement(self, tokens): + """Parse a filter statement and return a ParsedFilterStatement.""" + if not tokens: + self._parser_state.error('empty filter statement') + syscall_descriptors = [] + if tokens[0].type == 'LBRACE': + opening_brace = tokens.pop(0) + while tokens: + syscall_descriptors.extend( + self._parse_syscall_descriptor(tokens)) + if not tokens or tokens[0].type != 'COMMA': + break + tokens.pop(0) + if not tokens or tokens[0].type != 'RBRACE': + self._parser_state.error('unclosed brace', token=opening_brace) + tokens.pop(0) + else: + syscall_descriptors.extend(self._parse_syscall_descriptor(tokens)) + if not tokens: + self._parser_state.error('missing colon') + if tokens[0].type != 'COLON': + self._parser_state.error('invalid colon', token=tokens[0]) + tokens.pop(0) + parsed_filter = self.parse_filter(tokens) + if not syscall_descriptors: + return None + return ParsedFilterStatement(tuple(syscall_descriptors), parsed_filter) + + # include-statement = '@include' , posix-path + # ; + def _parse_include_statement(self, tokens): + if not tokens: + self._parser_state.error('empty filter statement') + if tokens[0].type != 'INCLUDE': + self._parser_state.error('invalid include', token=tokens[0]) + tokens.pop(0) + if not tokens: + self._parser_state.error('empty include path') + include_path = tokens.pop(0) + if include_path.type != 'PATH': + self._parser_state.error( + 'invalid include 
path', token=include_path) + if len(self._parser_states) == self._include_depth_limit: + self._parser_state.error('@include statement nested too deep') + include_filename = os.path.normpath( + os.path.join( + os.path.dirname(self._parser_state.filename), + include_path.value)) + if not os.path.isfile(include_filename): + self._parser_state.error( + 'Could not @include %s' % include_filename, token=include_path) + return self._parse_policy_file(include_filename) + + def _parse_frequency_file(self, filename): + self._parser_states.append(ParserState(filename)) + try: + frequency_mapping = collections.defaultdict(int) + with open(filename) as frequency_file: + for line in frequency_file: + self._parser_state.set_line(line.rstrip()) + tokens = self._parser_state.tokenize() + + if not tokens: + continue + + syscall_numbers = self._parse_syscall_descriptor(tokens) + if not tokens: + self._parser_state.error('missing colon') + if tokens[0].type != 'COLON': + self._parser_state.error( + 'invalid colon', token=tokens[0]) + tokens.pop(0) + + if not tokens: + self._parser_state.error('missing number') + number = tokens.pop(0) + if number.type != 'NUMERIC_CONSTANT': + self._parser_state.error( + 'invalid number', token=number) + number_value = int(number.value, base=0) + if number_value < 0: + self._parser_state.error( + 'invalid number', token=number) + + for syscall_number in syscall_numbers: + frequency_mapping[syscall_number] += number_value + return frequency_mapping + finally: + self._parser_states.pop() + + # frequency-statement = '@frequency' , posix-path + # ; + def _parse_frequency_statement(self, tokens): + if not tokens: + self._parser_state.error('empty frequency statement') + if tokens[0].type != 'FREQUENCY': + self._parser_state.error('invalid frequency', token=tokens[0]) + tokens.pop(0) + if not tokens: + self._parser_state.error('empty frequency path') + frequency_path = tokens.pop(0) + if frequency_path.type != 'PATH': + self._parser_state.error( + 'invalid 
frequency path', token=frequency_path) + frequency_filename = os.path.normpath( + os.path.join( + os.path.dirname(self._parser_state.filename), + frequency_path.value)) + if not os.path.isfile(frequency_filename): + self._parser_state.error( + 'Could not open frequency file %s' % frequency_filename, + token=frequency_path) + return self._parse_frequency_file(frequency_filename) + + # default-statement = '@default' , default-action + # ; + def _parse_default_statement(self, tokens): + if not tokens: + self._parser_state.error('empty default statement') + if tokens[0].type != 'DEFAULT': + self._parser_state.error('invalid default', token=tokens[0]) + tokens.pop(0) + if not tokens: + self._parser_state.error('empty action') + return self._parse_default_action(tokens) + + def _parse_policy_file(self, filename): + self._parser_states.append(ParserState(filename)) + try: + statements = [] + with open(filename) as policy_file: + for line in policy_file: + self._parser_state.set_line(line.rstrip()) + tokens = self._parser_state.tokenize() + + if not tokens: + # Allow empty lines. + continue + + if tokens[0].type == 'INCLUDE': + statements.extend( + self._parse_include_statement(tokens)) + elif tokens[0].type == 'FREQUENCY': + for syscall_number, frequency in self._parse_frequency_statement( + tokens).items(): + self._frequency_mapping[ + syscall_number] += frequency + elif tokens[0].type == 'DEFAULT': + self._default_action = self._parse_default_statement( + tokens) + else: + statement = self.parse_filter_statement(tokens) + if statement is None: + # If all the syscalls in the statement are for + # another arch, skip the whole statement. 
+ continue + statements.append(statement) + + if tokens: + self._parser_state.error( + 'extra tokens', token=tokens[0]) + return statements + finally: + self._parser_states.pop() + + def parse_file(self, filename): + """Parse a file and return the list of FilterStatements.""" + self._frequency_mapping = collections.defaultdict(int) + try: + statements = [x for x in self._parse_policy_file(filename)] + except RecursionError: + raise ParseException('recursion limit exceeded', filename, + self._parser_states[-1].line) + + # Collapse statements into a single syscall-to-filter-list. + syscall_filter_mapping = {} + filter_statements = [] + for syscalls, filters in statements: + for syscall in syscalls: + if syscall not in syscall_filter_mapping: + filter_statements.append( + FilterStatement( + syscall, self._frequency_mapping.get(syscall, 1), + [])) + syscall_filter_mapping[syscall] = filter_statements[-1] + syscall_filter_mapping[syscall].filters.extend(filters) + for filter_statement in filter_statements: + unconditional_actions_suffix = list( + itertools.dropwhile(lambda filt: filt.expression is not None, + filter_statement.filters)) + if len(unconditional_actions_suffix) == 1: + # The last filter already has an unconditional action, no need + # to add another one. 
+ continue + if len(unconditional_actions_suffix) > 1: + raise ParseException(('Syscall %s (number %d) already had ' + 'an unconditional action applied') % + (filter_statement.syscall.name, + filter_statement.syscall.number), + filename, self._parser_states[-1].line) + assert not unconditional_actions_suffix + filter_statement.filters.append( + Filter(expression=None, action=self._default_action)) + return ParsedPolicy(self._default_action, filter_statements) diff --git a/tools/parser_unittest.py b/tools/parser_unittest.py index d40ab42..4fba590 100755 --- a/tools/parser_unittest.py +++ b/tools/parser_unittest.py @@ -21,9 +21,12 @@ from __future__ import division from __future__ import print_function import os +import shutil +import tempfile import unittest import arch +import bpf import parser # pylint: disable=wrong-import-order ARCH_64 = arch.Arch.load_from_json( @@ -97,7 +100,8 @@ class ParseConstantTests(unittest.TestCase): def setUp(self): self.arch = ARCH_64 - self.parser = parser.PolicyParser(self.arch) + self.parser = parser.PolicyParser( + self.arch, kill_action=bpf.KillProcess()) def _tokenize(self, line): # pylint: disable=protected-access @@ -233,5 +237,401 @@ class ParseConstantTests(unittest.TestCase): self.parser.parse_value(self._tokenize('0|')) +class ParseFilterExpressionTests(unittest.TestCase): + """Tests for PolicyParser.parse_argument_expression.""" + + def setUp(self): + self.arch = ARCH_64 + self.parser = parser.PolicyParser( + self.arch, kill_action=bpf.KillProcess()) + + def _tokenize(self, line): + # pylint: disable=protected-access + self.parser._parser_state.set_line(line) + return self.parser._parser_state.tokenize() + + def test_parse_argument_expression(self): + """Accept valid argument expressions.""" + self.assertEqual( + self.parser.parse_argument_expression( + self._tokenize( + 'arg0 in 0xffff || arg0 == PROT_EXEC && arg1 == PROT_WRITE' + )), [ + [parser.Atom(0, 'in', 0xffff)], + [parser.Atom(0, '==', 4), + parser.Atom(1, '==', 
2)], + ]) + + def test_parse_empty_argument_expression(self): + """Reject empty argument expressions.""" + with self.assertRaisesRegex(parser.ParseException, + 'empty argument expression'): + self.parser.parse_argument_expression( + self._tokenize('arg0 in 0xffff ||')) + + def test_parse_empty_clause(self): + """Reject empty clause.""" + with self.assertRaisesRegex(parser.ParseException, 'empty clause'): + self.parser.parse_argument_expression( + self._tokenize('arg0 in 0xffff &&')) + + def test_parse_invalid_argument(self): + """Reject invalid argument.""" + with self.assertRaisesRegex(parser.ParseException, 'invalid argument'): + self.parser.parse_argument_expression( + self._tokenize('argX in 0xffff')) + + def test_parse_invalid_operator(self): + """Reject invalid operator.""" + with self.assertRaisesRegex(parser.ParseException, 'invalid operator'): + self.parser.parse_argument_expression( + self._tokenize('arg0 = 0xffff')) + + +class ParseFilterTests(unittest.TestCase): + """Tests for PolicyParser.parse_filter.""" + + def setUp(self): + self.arch = ARCH_64 + self.parser = parser.PolicyParser( + self.arch, kill_action=bpf.KillProcess()) + + def _tokenize(self, line): + # pylint: disable=protected-access + self.parser._parser_state.set_line(line) + return self.parser._parser_state.tokenize() + + def test_parse_filter(self): + """Accept valid filters.""" + self.assertEqual( + self.parser.parse_filter(self._tokenize('arg0 == 0')), [ + parser.Filter([[parser.Atom(0, '==', 0)]], bpf.Allow()), + ]) + self.assertEqual( + self.parser.parse_filter(self._tokenize('kill-process')), [ + parser.Filter(None, bpf.KillProcess()), + ]) + self.assertEqual( + self.parser.parse_filter(self._tokenize('kill-thread')), [ + parser.Filter(None, bpf.KillThread()), + ]) + self.assertEqual( + self.parser.parse_filter(self._tokenize('trap')), [ + parser.Filter(None, bpf.Trap()), + ]) + self.assertEqual( + self.parser.parse_filter(self._tokenize('return ENOSYS')), [ + parser.Filter(None, + 
bpf.ReturnErrno(self.arch.constants['ENOSYS'])), + ]) + self.assertEqual( + self.parser.parse_filter(self._tokenize('trace')), [ + parser.Filter(None, bpf.Trace()), + ]) + self.assertEqual( + self.parser.parse_filter(self._tokenize('log')), [ + parser.Filter(None, bpf.Log()), + ]) + self.assertEqual( + self.parser.parse_filter(self._tokenize('allow')), [ + parser.Filter(None, bpf.Allow()), + ]) + self.assertEqual( + self.parser.parse_filter(self._tokenize('1')), [ + parser.Filter(None, bpf.Allow()), + ]) + self.assertEqual( + self.parser.parse_filter( + self._tokenize( + '{ arg0 == 0, arg0 == 1; return ENOSYS, trap }')), + [ + parser.Filter([[parser.Atom(0, '==', 0)]], bpf.Allow()), + parser.Filter([[parser.Atom(0, '==', 1)]], + bpf.ReturnErrno(self.arch.constants['ENOSYS'])), + parser.Filter(None, bpf.Trap()), + ]) + + def test_parse_missing_return_value(self): + """Reject missing return value.""" + with self.assertRaisesRegex(parser.ParseException, + 'missing return value'): + self.parser.parse_filter(self._tokenize('return')) + + def test_parse_invalid_return_value(self): + """Reject invalid return value.""" + with self.assertRaisesRegex(parser.ParseException, 'invalid constant'): + self.parser.parse_filter(self._tokenize('return arg0')) + + def test_parse_unclosed_brace(self): + """Reject unclosed brace.""" + with self.assertRaisesRegex(parser.ParseException, 'unclosed brace'): + self.parser.parse_filter(self._tokenize('{ allow')) + + +class ParseFilterStatementTests(unittest.TestCase): + """Tests for PolicyParser.parse_filter_statement.""" + + def setUp(self): + self.arch = ARCH_64 + self.parser = parser.PolicyParser( + self.arch, kill_action=bpf.KillProcess()) + + def _tokenize(self, line): + # pylint: disable=protected-access + self.parser._parser_state.set_line(line) + return self.parser._parser_state.tokenize() + + def test_parse_filter_statement(self): + """Accept valid filter statements.""" + self.assertEqual( + self.parser.parse_filter_statement( + 
self._tokenize('read: arg0 == 0')), + parser.ParsedFilterStatement((parser.Syscall('read', 0), ), [ + parser.Filter([[parser.Atom(0, '==', 0)]], bpf.Allow()), + ])) + self.assertEqual( + self.parser.parse_filter_statement( + self._tokenize('{read, write}: arg0 == 0')), + parser.ParsedFilterStatement(( + parser.Syscall('read', 0), + parser.Syscall('write', 1), + ), [ + parser.Filter([[parser.Atom(0, '==', 0)]], bpf.Allow()), + ])) + + def test_parse_metadata(self): + """Accept valid filter statements with metadata.""" + self.assertEqual( + self.parser.parse_filter_statement( + self._tokenize('read[arch=test]: arg0 == 0')), + parser.ParsedFilterStatement((parser.Syscall('read', 0), ), [ + parser.Filter([[parser.Atom(0, '==', 0)]], bpf.Allow()), + ])) + self.assertEqual( + self.parser.parse_filter_statement( + self._tokenize( + '{read, nonexistent[arch=nonexistent]}: arg0 == 0')), + parser.ParsedFilterStatement((parser.Syscall('read', 0), ), [ + parser.Filter([[parser.Atom(0, '==', 0)]], bpf.Allow()), + ])) + + def test_parse_unclosed_brace(self): + """Reject unclosed brace.""" + with self.assertRaisesRegex(parser.ParseException, 'unclosed brace'): + self.parser.parse_filter_statement( + self._tokenize('{ read, write: arg0 == 0')) + + def test_parse_missing_colon(self): + """Reject missing colon.""" + with self.assertRaisesRegex(parser.ParseException, 'missing colon'): + self.parser.parse_filter_statement(self._tokenize('read')) + + def test_parse_invalid_colon(self): + """Reject invalid colon.""" + with self.assertRaisesRegex(parser.ParseException, 'invalid colon'): + self.parser.parse_filter_statement(self._tokenize('read arg0')) + + def test_parse_missing_filter(self): + """Reject missing filter.""" + with self.assertRaisesRegex(parser.ParseException, 'missing filter'): + self.parser.parse_filter_statement(self._tokenize('read:')) + + +class ParseFileTests(unittest.TestCase): + """Tests for PolicyParser.parse_file.""" + + def setUp(self): + self.arch = ARCH_64 + 
self.parser = parser.PolicyParser( + self.arch, kill_action=bpf.KillProcess()) + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def _write_file(self, filename, contents): + """Helper to write out a file for testing.""" + path = os.path.join(self.tempdir, filename) + with open(path, 'w') as outf: + outf.write(contents) + return path + + def test_parse_simple(self): + """Allow simple policy files.""" + path = self._write_file( + 'test.policy', """ + # Comment. + read: allow + write: allow + """) + + self.assertEqual( + self.parser.parse_file(path), + parser.ParsedPolicy( + default_action=bpf.KillProcess(), + filter_statements=[ + parser.FilterStatement( + syscall=parser.Syscall('read', 0), + frequency=1, + filters=[ + parser.Filter(None, bpf.Allow()), + ]), + parser.FilterStatement( + syscall=parser.Syscall('write', 1), + frequency=1, + filters=[ + parser.Filter(None, bpf.Allow()), + ]), + ])) + + def test_parse_default(self): + """Allow defining a default action.""" + path = self._write_file( + 'test.policy', """ + @default kill-thread + read: allow + """) + + self.assertEqual( + self.parser.parse_file(path), + parser.ParsedPolicy( + default_action=bpf.KillThread(), + filter_statements=[ + parser.FilterStatement( + syscall=parser.Syscall('read', 0), + frequency=1, + filters=[ + parser.Filter(None, bpf.Allow()), + ]), + ])) + + def test_parse_default_permissive(self): + """Reject defining a permissive default action.""" + path = self._write_file( + 'test.policy', """ + @default log + read: allow + """) + + with self.assertRaisesRegex(parser.ParseException, + r'invalid permissive default action'): + self.parser.parse_file(path) + + def test_parse_simple_grouped(self): + """Allow simple policy files.""" + path = self._write_file( + 'test.policy', """ + # Comment. 
+ {read, write}: allow + """) + + self.assertEqual( + self.parser.parse_file(path), + parser.ParsedPolicy( + default_action=bpf.KillProcess(), + filter_statements=[ + parser.FilterStatement( + syscall=parser.Syscall('read', 0), + frequency=1, + filters=[ + parser.Filter(None, bpf.Allow()), + ]), + parser.FilterStatement( + syscall=parser.Syscall('write', 1), + frequency=1, + filters=[ + parser.Filter(None, bpf.Allow()), + ]), + ])) + + def test_parse_include(self): + """Allow including policy files.""" + path = self._write_file( + 'test.include.policy', """ + {read, write}: arg0 == 0; allow + """) + path = self._write_file( + 'test.policy', """ + @include ./test.include.policy + read: return ENOSYS + """) + + self.assertEqual( + self.parser.parse_file(path), + parser.ParsedPolicy( + default_action=bpf.KillProcess(), + filter_statements=[ + parser.FilterStatement( + syscall=parser.Syscall('read', 0), + frequency=1, + filters=[ + parser.Filter([[parser.Atom(0, '==', 0)]], + bpf.Allow()), + parser.Filter( + None, + bpf.ReturnErrno( + self.arch.constants['ENOSYS'])), + ]), + parser.FilterStatement( + syscall=parser.Syscall('write', 1), + frequency=1, + filters=[ + parser.Filter([[parser.Atom(0, '==', 0)]], + bpf.Allow()), + parser.Filter(None, bpf.KillProcess()), + ]), + ])) + + def test_parse_frequency(self): + """Allow including frequency files.""" + self._write_file( + 'test.frequency', """ + read: 2 + write: 3 + """) + path = self._write_file( + 'test.policy', """ + @frequency ./test.frequency + read: allow + """) + + self.assertEqual( + self.parser.parse_file(path), + parser.ParsedPolicy( + default_action=bpf.KillProcess(), + filter_statements=[ + parser.FilterStatement( + syscall=parser.Syscall('read', 0), + frequency=2, + filters=[ + parser.Filter(None, bpf.Allow()), + ]), + ])) + + def test_parse_multiple_unconditional(self): + """Reject actions after an unconditional action.""" + path = self._write_file( + 'test.policy', """ + read: allow + read: allow + """) 
+ + with self.assertRaisesRegex( + parser.ParseException, + r'Syscall read.*already had an unconditional action applied'): + self.parser.parse_file(path) + + path = self._write_file( + 'test.policy', """ + read: log + read: arg0 == 0; log + """) + + with self.assertRaisesRegex( + parser.ParseException, + r'Syscall read.*already had an unconditional action applied'): + self.parser.parse_file(path) + + if __name__ == '__main__': unittest.main() diff --git a/tools/testdata/arch_64.json b/tools/testdata/arch_64.json index c23f988..bd3e2f4 100644 --- a/tools/testdata/arch_64.json +++ b/tools/testdata/arch_64.json @@ -4,9 +4,12 @@ "bits": 64, "syscalls": { "read": 0, - "write": 1 + "write": 1, + "open": 2, + "close": 3 }, "constants": { + "ENOSYS": 38, "O_RDONLY": 0, "PROT_WRITE": 2, "PROT_EXEC": 4 |