author     Bill Rassieur <rassb@google.com>    2019-03-29 04:15:40 +0000
committer  Bill Rassieur <rassb@google.com>    2019-03-29 04:15:40 +0000
commit     3cd28db90a6421187338e6c44b4b884af8790dd6 (patch)
tree       339b9122dc1427f93693d451c5c494eb67fb85de
parent     788d5b1f188aeb6131746dbe88af497424102b8b (diff)
parent     a223091ff4a7cf746dd16312c483241ef602a48c (diff)
download   minijail-3cd28db90a6421187338e6c44b4b884af8790dd6.tar.gz
Merge master@5406228 into git_qt-dev-plus-aosp.
Change-Id: I5dc2e5fda83fd1350b0383b373f22cb1dc0d198c
BUG: 129345239

-rw-r--r--   minijail0.5                    2
-rw-r--r--   syscall_filter.c              41
-rw-r--r--   syscall_filter_unittest.cc    35
-rw-r--r--   tools/bpf.py                 653
-rw-r--r--   tools/compiler.py             94
-rwxr-xr-x   tools/compiler_unittest.py   276
-rw-r--r--   tools/parser.py              483
-rwxr-xr-x   tools/parser_unittest.py     402
-rw-r--r--   tools/testdata/arch_64.json    5
9 files changed, 1987 insertions(+), 4 deletions(-)
diff --git a/minijail0.5 b/minijail0.5
index bf6d0d7..f7b411a 100644
--- a/minijail0.5
+++ b/minijail0.5
@@ -34,6 +34,8 @@ The policy file supplied to the \fB-S\fR argument supports the following syntax:
\fB<empty line>\fR
\fB# any single line comment\fR
+Long lines may be broken up using \\ at the end.
+
A policy that emulates \fBseccomp\fR(2) in mode 1 may look like:
read: 1
write: 1
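
The hunk above documents the new backslash line-continuation syntax for policy files. As a hypothetical illustration (not part of this change; the PR_* constant names are placeholders), a wrapped filter line is equivalent to the same line written out in full:

prctl: arg0 == PR_SET_NAME || \
       arg0 == PR_GET_NAME

# equivalent to:
prctl: arg0 == PR_SET_NAME || arg0 == PR_GET_NAME

The continuation pieces are joined with a single space, so the break can fall anywhere whitespace is already allowed.
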
diff --git a/syscall_filter.c b/syscall_filter.c
index c1526a4..049797a 100644
--- a/syscall_filter.c
+++ b/syscall_filter.c
@@ -499,6 +499,45 @@ int parse_include_statement(struct parser_state *state, char *policy_line,
return 0;
}
+/*
+ * This is like getline() but supports line wrapping with \.
+ */
+static ssize_t getmultiline(char **lineptr, size_t *n, FILE *stream)
+{
+ ssize_t ret = getline(lineptr, n, stream);
+ if (ret < 0)
+ return ret;
+
+ char *line = *lineptr;
+ /* Eat the newline to make processing below easier. */
+ if (ret > 0 && line[ret - 1] == '\n')
+ line[--ret] = '\0';
+
+ /* If the line doesn't end in a backslash, we're done. */
+ if (ret <= 0 || line[ret - 1] != '\\')
+ return ret;
+
+ /* This line ends in a backslash. Get the next line. */
+ line[--ret] = '\0';
+ size_t next_n = 0;
+ char *next_line = NULL;
+ ssize_t next_ret = getmultiline(&next_line, &next_n, stream);
+ if (next_ret == -1) {
+ free(next_line);
+ /* We couldn't fully read the line, so return an error. */
+ return -1;
+ }
+
+ /* Merge the lines. */
+ *n = ret + next_ret + 2;
+ line = realloc(line, *n);
+ line[ret] = ' ';
+ memcpy(&line[ret + 1], next_line, next_ret + 1);
+ free(next_line);
+ *lineptr = line;
+ return ret;
+}
+
int compile_file(const char *filename, FILE *policy_file,
struct filter_block *head, struct filter_block **arg_blocks,
struct bpf_labels *labels, int use_ret_trap, int allow_logging,
@@ -522,7 +561,7 @@ int compile_file(const char *filename, FILE *policy_file,
size_t len = 0;
int ret = 0;
- while (getline(&line, &len, policy_file) != -1) {
+ while (getmultiline(&line, &len, policy_file) != -1) {
char *policy_line = line;
policy_line = strip(policy_line);
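
The new getmultiline() helper recursively reads continued lines and joins the pieces with a single space before handing the merged line to the existing parsing code. A minimal Python sketch of the same joining semantics (illustrative only, assuming a line-iterable stream; not part of the patch):

import io

def read_multiline(stream):
    """Yield logical lines, joining backslash-continued lines with one space."""
    pending = []
    for raw in stream:
        piece = raw.rstrip('\n')
        if piece.endswith('\\'):
            pending.append(piece[:-1])
            continue
        pending.append(piece)
        yield ' '.join(pending)
        pending = []
    if pending:
        # A trailing continuation with no following line is an error,
        # mirroring the -1 return in the C helper.
        raise ValueError('dangling line continuation')

print(list(read_multiline(io.StringIO('read:\\\n1\nwrite: 1\n'))))
# ['read: 1', 'write: 1']
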
diff --git a/syscall_filter_unittest.cc b/syscall_filter_unittest.cc
index 8a0b19a..aca5f54 100644
--- a/syscall_filter_unittest.cc
+++ b/syscall_filter_unittest.cc
@@ -1261,6 +1261,41 @@ TEST_F(FileTest, seccomp_read) {
EXPECT_EQ(curr_block->next, nullptr);
}
+TEST_F(FileTest, multiline) {
+ std::string policy =
+ "read:\\\n1\n"
+ "openat:arg0 in\\\n5";
+
+ const int LABEL_ID = 0;
+
+ FILE* policy_file = write_policy_to_pipe(policy);
+ ASSERT_NE(policy_file, nullptr);
+ int res = test_compile_file("policy", policy_file, head_, &arg_blocks_,
+ &labels_);
+ fclose(policy_file);
+
+ /*
+ * Policy should be valid.
+ */
+ ASSERT_EQ(res, 0);
+
+ /* First block is the read. */
+ struct filter_block *curr_block = head_;
+ ASSERT_NE(curr_block, nullptr);
+ EXPECT_ALLOW_SYSCALL(curr_block->instrs, __NR_read);
+
+ /* Second block is the open. */
+ curr_block = curr_block->next;
+ ASSERT_NE(curr_block, nullptr);
+ EXPECT_ALLOW_SYSCALL_ARGS(curr_block->instrs,
+ __NR_openat,
+ LABEL_ID,
+ JUMP_JT,
+ JUMP_JF);
+
+ EXPECT_EQ(curr_block->next, nullptr);
+}
+
TEST(FilterTest, seccomp_mode1) {
struct sock_fprog actual;
std::string policy =
diff --git a/tools/bpf.py b/tools/bpf.py
new file mode 100644
index 0000000..e89e93f
--- /dev/null
+++ b/tools/bpf.py
@@ -0,0 +1,653 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2018 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tools to interact with BPF programs."""
+
+import abc
+import collections
+import struct
+
+# This comes from syscall(2). Most architectures only support passing 6 args to
+# syscalls, but ARM supports passing 7.
+MAX_SYSCALL_ARGUMENTS = 7
+
+# The following fields were copied from <linux/bpf_common.h>:
+
+# Instruction classes
+BPF_LD = 0x00
+BPF_LDX = 0x01
+BPF_ST = 0x02
+BPF_STX = 0x03
+BPF_ALU = 0x04
+BPF_JMP = 0x05
+BPF_RET = 0x06
+BPF_MISC = 0x07
+
+# LD/LDX fields.
+# Size
+BPF_W = 0x00
+BPF_H = 0x08
+BPF_B = 0x10
+# Mode
+BPF_IMM = 0x00
+BPF_ABS = 0x20
+BPF_IND = 0x40
+BPF_MEM = 0x60
+BPF_LEN = 0x80
+BPF_MSH = 0xa0
+
+# JMP fields.
+BPF_JA = 0x00
+BPF_JEQ = 0x10
+BPF_JGT = 0x20
+BPF_JGE = 0x30
+BPF_JSET = 0x40
+
+# Source
+BPF_K = 0x00
+BPF_X = 0x08
+
+BPF_MAXINSNS = 4096
+
+# The following fields were copied from <linux/seccomp.h>:
+
+SECCOMP_RET_KILL_PROCESS = 0x80000000
+SECCOMP_RET_KILL_THREAD = 0x00000000
+SECCOMP_RET_TRAP = 0x00030000
+SECCOMP_RET_ERRNO = 0x00050000
+SECCOMP_RET_TRACE = 0x7ff00000
+SECCOMP_RET_LOG = 0x7ffc0000
+SECCOMP_RET_ALLOW = 0x7fff0000
+
+SECCOMP_RET_ACTION_FULL = 0xffff0000
+SECCOMP_RET_DATA = 0x0000ffff
+
+
+def arg_offset(arg_index, hi=False):
+ """Return the BPF_LD|BPF_W|BPF_ABS addressing-friendly register offset."""
+ offsetof_args = 4 + 4 + 8
+ arg_width = 8
+ return offsetof_args + arg_width * arg_index + (arg_width // 2) * hi
+
+
+def simulate(instructions, arch, syscall_number, *args):
+ """Simulate a BPF program with the given arguments."""
+ args = ((args + (0, ) *
+ (MAX_SYSCALL_ARGUMENTS - len(args)))[:MAX_SYSCALL_ARGUMENTS])
+ input_memory = struct.pack('IIQ' + 'Q' * MAX_SYSCALL_ARGUMENTS,
+ syscall_number, arch, 0, *args)
+
+ register = 0
+ program_counter = 0
+ cost = 0
+ while program_counter < len(instructions):
+ ins = instructions[program_counter]
+ program_counter += 1
+ cost += 1
+ if ins.code == BPF_LD | BPF_W | BPF_ABS:
+ register = struct.unpack('I', input_memory[ins.k:ins.k + 4])[0]
+ elif ins.code == BPF_JMP | BPF_JA | BPF_K:
+ program_counter += ins.k
+ elif ins.code == BPF_JMP | BPF_JEQ | BPF_K:
+ if register == ins.k:
+ program_counter += ins.jt
+ else:
+ program_counter += ins.jf
+ elif ins.code == BPF_JMP | BPF_JGT | BPF_K:
+ if register > ins.k:
+ program_counter += ins.jt
+ else:
+ program_counter += ins.jf
+ elif ins.code == BPF_JMP | BPF_JGE | BPF_K:
+ if register >= ins.k:
+ program_counter += ins.jt
+ else:
+ program_counter += ins.jf
+ elif ins.code == BPF_JMP | BPF_JSET | BPF_K:
+ if register & ins.k != 0:
+ program_counter += ins.jt
+ else:
+ program_counter += ins.jf
+ elif ins.code == BPF_RET:
+ if ins.k == SECCOMP_RET_KILL_PROCESS:
+ return (cost, 'KILL_PROCESS')
+ if ins.k == SECCOMP_RET_KILL_THREAD:
+ return (cost, 'KILL_THREAD')
+ if ins.k == SECCOMP_RET_TRAP:
+ return (cost, 'TRAP')
+ if (ins.k & SECCOMP_RET_ACTION_FULL) == SECCOMP_RET_ERRNO:
+ return (cost, 'ERRNO', ins.k & SECCOMP_RET_DATA)
+ if ins.k == SECCOMP_RET_TRACE:
+ return (cost, 'TRACE')
+ if ins.k == SECCOMP_RET_LOG:
+ return (cost, 'LOG')
+ if ins.k == SECCOMP_RET_ALLOW:
+ return (cost, 'ALLOW')
+ raise Exception('unknown return %#x' % ins.k)
+ else:
+ raise Exception('unknown instruction %r' % (ins, ))
+ raise Exception('out-of-bounds')
+
+
+class SockFilter(
+ collections.namedtuple('SockFilter', ['code', 'jt', 'jf', 'k'])):
+ """A representation of struct sock_filter."""
+
+ __slots__ = ()
+
+ def encode(self):
+ """Return an encoded version of the SockFilter."""
+ return struct.pack('HBBI', self.code, self.jt, self.jf, self.k)
+
+
+class AbstractBlock(abc.ABC):
+ """A class that implements the visitor pattern."""
+
+ def __init__(self):
+ super().__init__()
+
+ @abc.abstractmethod
+ def accept(self, visitor):
+ pass
+
+
+class BasicBlock(AbstractBlock):
+ """A concrete implementation of AbstractBlock that has been compiled."""
+
+ def __init__(self, instructions):
+ super().__init__()
+ self._instructions = instructions
+
+ def accept(self, visitor):
+ visitor.visit(self)
+
+ @property
+ def instructions(self):
+ return self._instructions
+
+ @property
+ def opcodes(self):
+ return b''.join(i.encode() for i in self._instructions)
+
+ def __eq__(self, o):
+ if not isinstance(o, BasicBlock):
+ return False
+ return self._instructions == o._instructions
+
+
+class KillProcess(BasicBlock):
+ """A BasicBlock that unconditionally returns KILL_PROCESS."""
+
+ def __init__(self):
+ super().__init__(
+ [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_KILL_PROCESS)])
+
+
+class KillThread(BasicBlock):
+ """A BasicBlock that unconditionally returns KILL_THREAD."""
+
+ def __init__(self):
+ super().__init__(
+ [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_KILL_THREAD)])
+
+
+class Trap(BasicBlock):
+ """A BasicBlock that unconditionally returns TRAP."""
+
+ def __init__(self):
+ super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_TRAP)])
+
+
+class Trace(BasicBlock):
+ """A BasicBlock that unconditionally returns TRACE."""
+
+ def __init__(self):
+ super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_TRACE)])
+
+
+class Log(BasicBlock):
+ """A BasicBlock that unconditionally returns LOG."""
+
+ def __init__(self):
+ super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_LOG)])
+
+
+class ReturnErrno(BasicBlock):
+ """A BasicBlock that unconditionally returns the specified errno."""
+
+ def __init__(self, errno):
+ super().__init__([
+ SockFilter(BPF_RET, 0x00, 0x00,
+ SECCOMP_RET_ERRNO | (errno & SECCOMP_RET_DATA))
+ ])
+ self.errno = errno
+
+
+class Allow(BasicBlock):
+ """A BasicBlock that unconditionally returns ALLOW."""
+
+ def __init__(self):
+ super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_ALLOW)])
+
+
+class ValidateArch(AbstractBlock):
+ """An AbstractBlock that validates the architecture."""
+
+ def __init__(self, next_block):
+ super().__init__()
+ self.next_block = next_block
+
+ def accept(self, visitor):
+ self.next_block.accept(visitor)
+ visitor.visit(self)
+
+
+class SyscallEntry(AbstractBlock):
+ """An abstract block that represents a syscall comparison in a DAG."""
+
+ def __init__(self, syscall_number, jt, jf, *, op=BPF_JEQ):
+ super().__init__()
+ self.op = op
+ self.syscall_number = syscall_number
+ self.jt = jt
+ self.jf = jf
+
+ def __lt__(self, o):
+ # Defined because we want to compare tuples that contain SyscallEntries.
+ return False
+
+ def __gt__(self, o):
+ # Defined because we want to compare tuples that contain SyscallEntries.
+ return False
+
+ def accept(self, visitor):
+ self.jt.accept(visitor)
+ self.jf.accept(visitor)
+ visitor.visit(self)
+
+
+class WideAtom(AbstractBlock):
+ """A BasicBlock that represents a 32-bit wide atom."""
+
+ def __init__(self, arg_offset, op, value, jt, jf):
+ super().__init__()
+ self.arg_offset = arg_offset
+ self.op = op
+ self.value = value
+ self.jt = jt
+ self.jf = jf
+
+ def accept(self, visitor):
+ self.jt.accept(visitor)
+ self.jf.accept(visitor)
+ visitor.visit(self)
+
+
+class Atom(AbstractBlock):
+ """A BasicBlock that represents an atom (a simple comparison operation)."""
+
+ def __init__(self, arg_index, op, value, jt, jf):
+ super().__init__()
+ if op == '==':
+ op = BPF_JEQ
+ elif op == '!=':
+ op = BPF_JEQ
+ jt, jf = jf, jt
+ elif op == '>':
+ op = BPF_JGT
+ elif op == '<=':
+ op = BPF_JGT
+ jt, jf = jf, jt
+ elif op == '>=':
+ op = BPF_JGE
+ elif op == '<':
+ op = BPF_JGE
+ jt, jf = jf, jt
+ elif op == '&':
+ op = BPF_JSET
+ elif op == 'in':
+ op = BPF_JSET
+ # The mask is negated, so the comparison will be true when the
+ # argument includes a flag that wasn't listed in the original
+ # (non-negated) mask. This would be the failure case, so we switch
+ # |jt| and |jf|.
+ value = (~value) & ((1 << 64) - 1)
+ jt, jf = jf, jt
+ else:
+ raise Exception('Unknown operator %s' % op)
+
+ self.arg_index = arg_index
+ self.op = op
+ self.jt = jt
+ self.jf = jf
+ self.value = value
+
+ def accept(self, visitor):
+ self.jt.accept(visitor)
+ self.jf.accept(visitor)
+ visitor.visit(self)
+
+
+class AbstractVisitor(abc.ABC):
+ """An abstract visitor."""
+
+ def process(self, block):
+ block.accept(self)
+ return block
+
+ def visit(self, block):
+ if isinstance(block, KillProcess):
+ self.visitKillProcess(block)
+ elif isinstance(block, KillThread):
+ self.visitKillThread(block)
+ elif isinstance(block, Trap):
+ self.visitTrap(block)
+ elif isinstance(block, ReturnErrno):
+ self.visitReturnErrno(block)
+ elif isinstance(block, Trace):
+ self.visitTrace(block)
+ elif isinstance(block, Log):
+ self.visitLog(block)
+ elif isinstance(block, Allow):
+ self.visitAllow(block)
+ elif isinstance(block, BasicBlock):
+ self.visitBasicBlock(block)
+ elif isinstance(block, ValidateArch):
+ self.visitValidateArch(block)
+ elif isinstance(block, SyscallEntry):
+ self.visitSyscallEntry(block)
+ elif isinstance(block, WideAtom):
+ self.visitWideAtom(block)
+ elif isinstance(block, Atom):
+ self.visitAtom(block)
+ else:
+ raise Exception('Unknown block type: %r' % block)
+
+ @abc.abstractmethod
+ def visitKillProcess(self, block):
+ pass
+
+ @abc.abstractmethod
+ def visitKillThread(self, block):
+ pass
+
+ @abc.abstractmethod
+ def visitTrap(self, block):
+ pass
+
+ @abc.abstractmethod
+ def visitReturnErrno(self, block):
+ pass
+
+ @abc.abstractmethod
+ def visitTrace(self, block):
+ pass
+
+ @abc.abstractmethod
+ def visitLog(self, block):
+ pass
+
+ @abc.abstractmethod
+ def visitAllow(self, block):
+ pass
+
+ @abc.abstractmethod
+ def visitBasicBlock(self, block):
+ pass
+
+ @abc.abstractmethod
+ def visitValidateArch(self, block):
+ pass
+
+ @abc.abstractmethod
+ def visitSyscallEntry(self, block):
+ pass
+
+ @abc.abstractmethod
+ def visitWideAtom(self, block):
+ pass
+
+ @abc.abstractmethod
+ def visitAtom(self, block):
+ pass
+
+
+class CopyingVisitor(AbstractVisitor):
+ """A visitor that copies Blocks."""
+
+ def __init__(self):
+ self._mapping = {}
+
+ def process(self, block):
+ self._mapping = {}
+ block.accept(self)
+ return self._mapping[id(block)]
+
+ def visitKillProcess(self, block):
+ if id(block) in self._mapping:
+ return
+ self._mapping[id(block)] = KillProcess()
+
+ def visitKillThread(self, block):
+ if id(block) in self._mapping:
+ return
+ self._mapping[id(block)] = KillThread()
+
+ def visitTrap(self, block):
+ if id(block) in self._mapping:
+ return
+ self._mapping[id(block)] = Trap()
+
+ def visitReturnErrno(self, block):
+ if id(block) in self._mapping:
+ return
+ self._mapping[id(block)] = ReturnErrno(block.errno)
+
+ def visitTrace(self, block):
+ if id(block) in self._mapping:
+ return
+ self._mapping[id(block)] = Trace()
+
+ def visitLog(self, block):
+ if id(block) in self._mapping:
+ return
+ self._mapping[id(block)] = Log()
+
+ def visitAllow(self, block):
+ if id(block) in self._mapping:
+ return
+ self._mapping[id(block)] = Allow()
+
+ def visitBasicBlock(self, block):
+ if id(block) in self._mapping:
+ return
+ self._mapping[id(block)] = BasicBlock(block.instructions)
+
+ def visitValidateArch(self, block):
+ if id(block) in self._mapping:
+ return
+ self._mapping[id(block)] = ValidateArch(
+ block.arch, self._mapping[id(block.next_block)])
+
+ def visitSyscallEntry(self, block):
+ if id(block) in self._mapping:
+ return
+ self._mapping[id(block)] = SyscallEntry(
+ block.syscall_number,
+ self._mapping[id(block.jt)],
+ self._mapping[id(block.jf)],
+ op=block.op)
+
+ def visitWideAtom(self, block):
+ if id(block) in self._mapping:
+ return
+ self._mapping[id(block)] = WideAtom(
+ block.arg_offset, block.op, block.value, self._mapping[id(
+ block.jt)], self._mapping[id(block.jf)])
+
+ def visitAtom(self, block):
+ if id(block) in self._mapping:
+ return
+ self._mapping[id(block)] = Atom(block.arg_index, block.op, block.value,
+ self._mapping[id(block.jt)],
+ self._mapping[id(block.jf)])
+
+
+class LoweringVisitor(CopyingVisitor):
+ """A visitor that lowers Atoms into WideAtoms."""
+
+ def __init__(self, *, arch):
+ super().__init__()
+ self._bits = arch.bits
+
+ def visitAtom(self, block):
+ if id(block) in self._mapping:
+ return
+
+ lo = block.value & 0xFFFFFFFF
+ hi = (block.value >> 32) & 0xFFFFFFFF
+
+ lo_block = WideAtom(
+ arg_offset(block.arg_index, False), block.op, lo,
+ self._mapping[id(block.jt)], self._mapping[id(block.jf)])
+
+ if self._bits == 32:
+ self._mapping[id(block)] = lo_block
+ return
+
+ if block.op in (BPF_JGE, BPF_JGT):
+ # hi_1,lo_1 <op> hi_2,lo_2
+ #
+ # hi_1 > hi_2 || hi_1 == hi_2 && lo_1 <op> lo_2
+ if hi == 0:
+ # Special case: it's not needed to check whether |hi_1 == hi_2|,
+ # because it's true iff the JGT test fails.
+ self._mapping[id(block)] = WideAtom(
+ arg_offset(block.arg_index, True), BPF_JGT, hi,
+ self._mapping[id(block.jt)], lo_block)
+ return
+ hi_eq_block = WideAtom(
+ arg_offset(block.arg_index, True), BPF_JEQ, hi, lo_block,
+ self._mapping[id(block.jf)])
+ self._mapping[id(block)] = WideAtom(
+ arg_offset(block.arg_index, True), BPF_JGT, hi,
+ self._mapping[id(block.jt)], hi_eq_block)
+ return
+ if block.op == BPF_JSET:
+ # hi_1,lo_1 & hi_2,lo_2
+ #
+ # hi_1 & hi_2 || lo_1 & lo_2
+ if hi == 0:
+ # Special case: |hi_1 & hi_2| will never be True, so jump
+ # directly into the |lo_1 & lo_2| case.
+ self._mapping[id(block)] = lo_block
+ return
+ self._mapping[id(block)] = WideAtom(
+ arg_offset(block.arg_index, True), block.op, hi,
+ self._mapping[id(block.jt)], lo_block)
+ return
+
+ assert block.op == BPF_JEQ, block.op
+
+ # hi_1,lo_1 == hi_2,lo_2
+ #
+ # hi_1 == hi_2 && lo_1 == lo_2
+ self._mapping[id(block)] = WideAtom(
+ arg_offset(block.arg_index, True), block.op, hi, lo_block,
+ self._mapping[id(block.jf)])
+
+
+class FlatteningVisitor:
+ """A visitor that flattens a DAG of Block objects."""
+
+ def __init__(self, *, arch, kill_action):
+ self._kill_action = kill_action
+ self._instructions = []
+ self._arch = arch
+ self._offsets = {}
+
+ @property
+ def result(self):
+ return BasicBlock(self._instructions)
+
+ def _distance(self, block):
+ distance = self._offsets[id(block)] + len(self._instructions)
+ assert distance >= 0
+ return distance
+
+ def _emit_load_arg(self, offset):
+ return [SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, offset)]
+
+ def _emit_jmp(self, op, value, jt_distance, jf_distance):
+ if jt_distance < 0x100 and jf_distance < 0x100:
+ return [
+ SockFilter(BPF_JMP | op | BPF_K, jt_distance, jf_distance,
+ value),
+ ]
+ if jt_distance + 1 < 0x100:
+ return [
+ SockFilter(BPF_JMP | op | BPF_K, jt_distance + 1, 0, value),
+ SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
+ ]
+ if jf_distance + 1 < 0x100:
+ return [
+ SockFilter(BPF_JMP | op | BPF_K, 0, jf_distance + 1, value),
+ SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance),
+ ]
+ return [
+ SockFilter(BPF_JMP | op | BPF_K, 0, 1, value),
+ SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance + 1),
+ SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
+ ]
+
+ def visit(self, block):
+ if id(block) in self._offsets:
+ return
+
+ if isinstance(block, BasicBlock):
+ instructions = block.instructions
+ elif isinstance(block, ValidateArch):
+ instructions = [
+ SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 4),
+ SockFilter(BPF_JMP | BPF_JEQ | BPF_K,
+ self._distance(block.next_block) + 1, 0,
+ self._arch.arch_nr),
+ ] + self._kill_action.instructions + [
+ SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 0),
+ ]
+ elif isinstance(block, SyscallEntry):
+ instructions = self._emit_jmp(block.op, block.syscall_number,
+ self._distance(block.jt),
+ self._distance(block.jf))
+ elif isinstance(block, WideAtom):
+ instructions = (
+ self._emit_load_arg(block.arg_offset) + self._emit_jmp(
+ block.op, block.value, self._distance(block.jt),
+ self._distance(block.jf)))
+ else:
+ raise Exception('Unknown block type: %r' % block)
+
+ self._instructions = instructions + self._instructions
+ self._offsets[id(block)] = -len(self._instructions)
+ return
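
bpf.simulate() above interprets a list of SockFilter instructions against a synthetic struct seccomp_data, which makes it possible to check hand-built or compiled filters without loading them into the kernel. A small usage sketch (illustrative only; the AUDIT_ARCH_X86_64 value is an assumption, and this particular filter never inspects the arch field):

import bpf

AUDIT_ARCH_X86_64 = 0xc000003e  # assumed value; only packed into the input memory here

instructions = [
    # Load the syscall number (offset 0 of struct seccomp_data) into the accumulator.
    bpf.SockFilter(bpf.BPF_LD | bpf.BPF_W | bpf.BPF_ABS, 0, 0, 0),
    # If it equals 0, fall through to ALLOW; otherwise skip one instruction to KILL.
    bpf.SockFilter(bpf.BPF_JMP | bpf.BPF_JEQ | bpf.BPF_K, 0, 1, 0),
    bpf.SockFilter(bpf.BPF_RET, 0, 0, bpf.SECCOMP_RET_ALLOW),
    bpf.SockFilter(bpf.BPF_RET, 0, 0, bpf.SECCOMP_RET_KILL_PROCESS),
]

print(bpf.simulate(instructions, AUDIT_ARCH_X86_64, 0))  # (3, 'ALLOW')
print(bpf.simulate(instructions, AUDIT_ARCH_X86_64, 1))  # (3, 'KILL_PROCESS')
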
diff --git a/tools/compiler.py b/tools/compiler.py
new file mode 100644
index 0000000..96800f1
--- /dev/null
+++ b/tools/compiler.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2018 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A BPF compiler for the Minijail policy file."""
+
+from __future__ import print_function
+
+import bpf
+import parser # pylint: disable=wrong-import-order
+
+
+class SyscallPolicyEntry:
+ """The parsed version of a seccomp policy line."""
+
+ def __init__(self, name, number, frequency):
+ self.name = name
+ self.number = number
+ self.frequency = frequency
+ self.accumulated = 0
+ self.filter = None
+
+ def __repr__(self):
+ return ('SyscallPolicyEntry<name: %s, number: %d, '
+ 'frequency: %d, filter: %r>') % (self.name, self.number,
+ self.frequency,
+ self.filter.instructions
+ if self.filter else None)
+
+ def simulate(self, arch, syscall_number, *args):
+ """Simulate the policy with the given arguments."""
+ if not self.filter:
+ return (0, 'ALLOW')
+ return bpf.simulate(self.filter.instructions, arch, syscall_number,
+ *args)
+
+
+class PolicyCompiler:
+ """A parser for the Minijail seccomp policy file format."""
+
+ def __init__(self, arch):
+ self._arch = arch
+
+ def compile_filter_statement(self, filter_statement, *, kill_action):
+ """Compile one parser.FilterStatement into BPF."""
+ policy_entry = SyscallPolicyEntry(filter_statement.syscall.name,
+ filter_statement.syscall.number,
+ filter_statement.frequency)
+ # In each step of the way, the false action is the one that is taken if
+ # the immediate boolean condition does not match. This means that the
+ # false action taken here is the one that applies if the whole
+ # expression fails to match.
+ false_action = filter_statement.filters[-1].action
+ if false_action == bpf.Allow():
+ return policy_entry
+ # We then traverse the list of filters backwards since we want
+ # the root of the DAG to be the very first boolean operation in
+ # the filter chain.
+ for filt in filter_statement.filters[:-1][::-1]:
+ for disjunction in filt.expression:
+ # This is the jump target of the very last comparison in the
+ # conjunction. Given that any conjunction that succeeds should
+ # make the whole expression succeed, make the very last
+ # comparison jump to the accept action if it succeeds.
+ true_action = filt.action
+ for atom in disjunction:
+ block = bpf.Atom(atom.argument_index, atom.op, atom.value,
+ true_action, false_action)
+ true_action = block
+ false_action = true_action
+ policy_filter = false_action
+
+ # Lower all Atoms into WideAtoms.
+ lowering_visitor = bpf.LoweringVisitor(arch=self._arch)
+ policy_filter = lowering_visitor.process(policy_filter)
+
+ # Flatten the IR DAG into a single BasicBlock.
+ flattening_visitor = bpf.FlatteningVisitor(
+ arch=self._arch, kill_action=kill_action)
+ policy_filter.accept(flattening_visitor)
+ policy_entry.filter = flattening_visitor.result
+ return policy_entry
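
Putting the parser and compiler together: a hedged usage sketch for compile_filter_statement(), assuming the tools directory is on sys.path, an Arch description such as testdata/arch_64.json, and a hypothetical test.policy file (the unit tests below do essentially the same thing through a helper):

import arch
import bpf
import compiler
import parser  # the local tools/parser.py module

test_arch = arch.Arch.load_from_json('testdata/arch_64.json')
policy_parser = parser.PolicyParser(test_arch, kill_action=bpf.KillProcess())
parsed_policy = policy_parser.parse_file('test.policy')  # hypothetical policy file

policy_compiler = compiler.PolicyCompiler(test_arch)
for statement in parsed_policy.filter_statements:
    entry = policy_compiler.compile_filter_statement(
        statement, kill_action=bpf.KillProcess())
    # entry.filter is None for an unconditional allow; entry.simulate() runs the
    # compiled BPF against concrete argument values.
    print(entry.name, entry.simulate(test_arch.arch_nr, entry.number, 0))
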
diff --git a/tools/compiler_unittest.py b/tools/compiler_unittest.py
new file mode 100755
index 0000000..ba66e62
--- /dev/null
+++ b/tools/compiler_unittest.py
@@ -0,0 +1,276 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2018 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unittests for the compiler module."""
+
+from __future__ import print_function
+
+import os
+import tempfile
+import unittest
+
+import arch
+import bpf
+import compiler
+import parser # pylint: disable=wrong-import-order
+
+ARCH_64 = arch.Arch.load_from_json(
+ os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), 'testdata/arch_64.json'))
+
+
+class CompileFilterStatementTests(unittest.TestCase):
+ """Tests for PolicyCompiler.compile_filter_statement."""
+
+ def setUp(self):
+ self.arch = ARCH_64
+ self.compiler = compiler.PolicyCompiler(self.arch)
+
+ def _compile(self, line):
+ with tempfile.NamedTemporaryFile(mode='w') as policy_file:
+ policy_file.write(line)
+ policy_file.flush()
+ policy_parser = parser.PolicyParser(
+ self.arch, kill_action=bpf.KillProcess())
+ parsed_policy = policy_parser.parse_file(policy_file.name)
+ assert len(parsed_policy.filter_statements) == 1
+ return self.compiler.compile_filter_statement(
+ parsed_policy.filter_statements[0],
+ kill_action=bpf.KillProcess())
+
+ def test_allow(self):
+ """Accept lines where the syscall is accepted unconditionally."""
+ block = self._compile('read: allow')
+ self.assertEqual(block.filter, None)
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 0)[1], 'ALLOW')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 1)[1], 'ALLOW')
+
+ def test_arg0_eq_generated_code(self):
+ """Accept lines with an argument filter with ==."""
+ block = self._compile('read: arg0 == 0x100')
+ # It might be a bit brittle to check the generated code in each test
+ # case instead of just the behavior, but there should be at least one
+ # test where this happens.
+ self.assertEqual(
+ block.filter.instructions,
+ [
+ bpf.SockFilter(bpf.BPF_LD | bpf.BPF_W | bpf.BPF_ABS, 0, 0,
+ bpf.arg_offset(0, True)),
+ # Jump to KILL_PROCESS if the high word does not match.
+ bpf.SockFilter(bpf.BPF_JMP | bpf.BPF_JEQ | bpf.BPF_K, 0, 2, 0),
+ bpf.SockFilter(bpf.BPF_LD | bpf.BPF_W | bpf.BPF_ABS, 0, 0,
+ bpf.arg_offset(0, False)),
+ # Jump to KILL_PROCESS if the low word does not match.
+ bpf.SockFilter(bpf.BPF_JMP | bpf.BPF_JEQ | bpf.BPF_K, 1, 0,
+ 0x100),
+ bpf.SockFilter(bpf.BPF_RET, 0, 0,
+ bpf.SECCOMP_RET_KILL_PROCESS),
+ bpf.SockFilter(bpf.BPF_RET, 0, 0, bpf.SECCOMP_RET_ALLOW),
+ ])
+
+ def test_arg0_comparison_operators(self):
+ """Accept lines with an argument filter with comparison operators."""
+ biases = (-1, 0, 1)
+ # For each operator, store the expectations of simulating the program
+ # against the constant plus each entry from the |biases| array.
+ cases = (
+ ('==', ('KILL_PROCESS', 'ALLOW', 'KILL_PROCESS')),
+ ('!=', ('ALLOW', 'KILL_PROCESS', 'ALLOW')),
+ ('<', ('ALLOW', 'KILL_PROCESS', 'KILL_PROCESS')),
+ ('<=', ('ALLOW', 'ALLOW', 'KILL_PROCESS')),
+ ('>', ('KILL_PROCESS', 'KILL_PROCESS', 'ALLOW')),
+ ('>=', ('KILL_PROCESS', 'ALLOW', 'ALLOW')),
+ )
+ for operator, expectations in cases:
+ block = self._compile('read: arg0 %s 0x100' % operator)
+
+ # Check the filter's behavior.
+ for bias, expectation in zip(biases, expectations):
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr,
+ self.arch.syscalls['read'],
+ 0x100 + bias)[1], expectation)
+
+ def test_arg0_mask_operator(self):
+ """Accept lines with an argument filter with &."""
+ block = self._compile('read: arg0 & 0x3')
+
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 0)[1], 'KILL_PROCESS')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 1)[1], 'ALLOW')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 2)[1], 'ALLOW')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 3)[1], 'ALLOW')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 4)[1], 'KILL_PROCESS')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 5)[1], 'ALLOW')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 6)[1], 'ALLOW')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 7)[1], 'ALLOW')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 8)[1], 'KILL_PROCESS')
+
+ def test_arg0_in_operator(self):
+ """Accept lines with an argument filter with in."""
+ block = self._compile('read: arg0 in 0x3')
+
+ # The 'in' operator only ensures that no bits outside the mask are set,
+ # which means that 0 is always allowed.
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 0)[1], 'ALLOW')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 1)[1], 'ALLOW')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 2)[1], 'ALLOW')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 3)[1], 'ALLOW')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 4)[1], 'KILL_PROCESS')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 5)[1], 'KILL_PROCESS')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 6)[1], 'KILL_PROCESS')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 7)[1], 'KILL_PROCESS')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 8)[1], 'KILL_PROCESS')
+
+ def test_arg0_short_gt_ge_comparisons(self):
+ """Ensure that the short comparison optimization kicks in."""
+ if self.arch.bits == 32:
+ return
+ short_constant_str = '0xdeadbeef'
+ short_constant = int(short_constant_str, base=0)
+ long_constant_str = '0xbadc0ffee0ddf00d'
+ long_constant = int(long_constant_str, base=0)
+ biases = (-1, 0, 1)
+ # For each operator, store the expectations of simulating the program
+ # against the constant plus each entry from the |biases| array.
+ cases = (
+ ('<', ('ALLOW', 'KILL_PROCESS', 'KILL_PROCESS')),
+ ('<=', ('ALLOW', 'ALLOW', 'KILL_PROCESS')),
+ ('>', ('KILL_PROCESS', 'KILL_PROCESS', 'ALLOW')),
+ ('>=', ('KILL_PROCESS', 'ALLOW', 'ALLOW')),
+ )
+ for operator, expectations in cases:
+ short_block = self._compile(
+ 'read: arg0 %s %s' % (operator, short_constant_str))
+ long_block = self._compile(
+ 'read: arg0 %s %s' % (operator, long_constant_str))
+
+ # Check that the emitted code is shorter when the high word of the
+ # constant is zero.
+ self.assertLess(
+ len(short_block.filter.instructions),
+ len(long_block.filter.instructions))
+
+ # Check the filter's behavior.
+ for bias, expectation in zip(biases, expectations):
+ self.assertEqual(
+ long_block.simulate(self.arch.arch_nr,
+ self.arch.syscalls['read'],
+ long_constant + bias)[1], expectation)
+ self.assertEqual(
+ short_block.simulate(
+ self.arch.arch_nr, self.arch.syscalls['read'],
+ short_constant + bias)[1], expectation)
+
+ def test_and_or(self):
+ """Accept lines with a complex expression in DNF."""
+ block = self._compile('read: arg0 == 0 && arg1 == 0 || arg0 == 1')
+
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], 0,
+ 0)[1], 'ALLOW')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], 0,
+ 1)[1], 'KILL_PROCESS')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], 1,
+ 0)[1], 'ALLOW')
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'], 1,
+ 1)[1], 'ALLOW')
+
+ def test_ret_errno(self):
+ """Accept lines that return errno."""
+ block = self._compile('read : arg0 == 0 || arg0 == 1 ; return 1')
+
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 0)[1:], ('ERRNO', 1))
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 1)[1:], ('ERRNO', 1))
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 2)[1], 'KILL_PROCESS')
+
+ def test_ret_errno_unconditionally(self):
+ """Accept lines that return errno unconditionally."""
+ block = self._compile('read: return 1')
+
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
+ 0)[1:], ('ERRNO', 1))
+
+ def test_mmap_write_xor_exec(self):
+ """Accept the idiomatic filter for mmap."""
+ block = self._compile(
+ 'read : arg0 in ~PROT_WRITE || arg0 in ~PROT_EXEC')
+
+ prot_exec_and_write = 6
+ for prot in range(0, 0xf):
+ if (prot & prot_exec_and_write) == prot_exec_and_write:
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr,
+ self.arch.syscalls['read'], prot)[1],
+ 'KILL_PROCESS')
+ else:
+ self.assertEqual(
+ block.simulate(self.arch.arch_nr,
+ self.arch.syscalls['read'], prot)[1],
+ 'ALLOW')
+
+
+if __name__ == '__main__':
+ unittest.main()
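
The '&' and 'in' tests above exercise the two mask operators. As a plain-Python restatement of the predicates they compile down to (illustrative only, for a mask of 0x3):

def mask_matches(arg, mask=0x3):
    # 'arg0 & 0x3': allow when any bit of the mask is set in the argument.
    return (arg & mask) != 0

def in_matches(arg, mask=0x3):
    # 'arg0 in 0x3': allow only when no bit outside the mask is set, so 0 always passes.
    return (arg & ~mask) == 0

assert [a for a in range(9) if mask_matches(a)] == [1, 2, 3, 5, 6, 7]
assert [a for a in range(9) if in_matches(a)] == [0, 1, 2, 3]
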
diff --git a/tools/parser.py b/tools/parser.py
index 05b6628..41d2f52 100644
--- a/tools/parser.py
+++ b/tools/parser.py
@@ -21,8 +21,12 @@ from __future__ import division
from __future__ import print_function
import collections
+import itertools
+import os.path
import re
+import bpf
+
Token = collections.namedtuple('token',
['type', 'value', 'filename', 'line', 'column'])
@@ -30,7 +34,9 @@ Token = collections.namedtuple('token',
_TOKEN_SPECIFICATION = (
('COMMENT', r'#.*$'),
('WHITESPACE', r'\s+'),
+ ('DEFAULT', r'@default'),
('INCLUDE', r'@include'),
+ ('FREQUENCY', r'@frequency'),
('PATH', r'(?:\.)?/\S+'),
('NUMERIC_CONSTANT', r'-?0[xX][0-9a-fA-F]+|-?0[Oo][0-7]+|-?[0-9]+'),
('COLON', r':'),
@@ -137,12 +143,51 @@ class ParserState:
return tokens
+Atom = collections.namedtuple('Atom', ['argument_index', 'op', 'value'])
+"""A single boolean comparison within a filter expression."""
+
+Filter = collections.namedtuple('Filter', ['expression', 'action'])
+"""The result of parsing a DNF filter expression, with its action.
+
+Since the expression is in Disjunctive Normal Form, it is composed of two levels
+of lists, one for disjunctions and the inner one for conjunctions. The elements
+of the inner list are Atoms.
+"""
+
+Syscall = collections.namedtuple('Syscall', ['name', 'number'])
+"""A system call."""
+
+ParsedFilterStatement = collections.namedtuple('ParsedFilterStatement',
+ ['syscalls', 'filters'])
+"""The result of parsing a filter statement.
+
+Statements have a list of syscalls, and an associated list of filters that will
+be evaluated sequentially when any of the syscalls is invoked.
+"""
+
+FilterStatement = collections.namedtuple('FilterStatement',
+ ['syscall', 'frequency', 'filters'])
+"""The filter list for a particular syscall.
+
+This is a mapping from one syscall to a list of filters that are evaluated
+sequentially. The last filter is always an unconditional action.
+"""
+
+ParsedPolicy = collections.namedtuple('ParsedPolicy',
+ ['default_action', 'filter_statements'])
+"""The result of parsing a minijail .policy file."""
+
+
# pylint: disable=too-few-public-methods
class PolicyParser:
"""A parser for the Minijail seccomp policy file format."""
- def __init__(self, arch):
+ def __init__(self, arch, *, kill_action, include_depth_limit=10):
self._parser_states = [ParserState("<memory>")]
+ self._kill_action = kill_action
+ self._include_depth_limit = include_depth_limit
+ self._default_action = self._kill_action
+ self._frequency_mapping = collections.defaultdict(int)
self._arch = arch
@property
@@ -228,3 +273,439 @@ class PolicyParser:
else:
self._parser_state.error('empty constant')
return value
+
+ # atom = argument , op , value
+ # ;
+ def _parse_atom(self, tokens):
+ if not tokens:
+ self._parser_state.error('missing argument')
+ argument = tokens.pop(0)
+ if argument.type != 'ARGUMENT':
+ self._parser_state.error('invalid argument', token=argument)
+
+ if not tokens:
+ self._parser_state.error('missing operator')
+ operator = tokens.pop(0)
+ if operator.type != 'OP':
+ self._parser_state.error('invalid operator', token=operator)
+
+ value = self.parse_value(tokens)
+ argument_index = int(argument.value[3:])
+ if not (0 <= argument_index < bpf.MAX_SYSCALL_ARGUMENTS):
+ self._parser_state.error('invalid argument', token=argument)
+ return Atom(argument_index, operator.value, value)
+
+ # clause = atom , [ { '&&' , atom } ]
+ # ;
+ def _parse_clause(self, tokens):
+ atoms = []
+ while tokens:
+ atoms.append(self._parse_atom(tokens))
+ if not tokens or tokens[0].type != 'AND':
+ break
+ tokens.pop(0)
+ else:
+ self._parser_state.error('empty clause')
+ return atoms
+
+ # argument-expression = clause , [ { '||' , clause } ]
+ # ;
+ def parse_argument_expression(self, tokens):
+ """Parse a argument expression in Disjunctive Normal Form.
+
+ Since BPF disallows back jumps, we build the basic blocks in reverse
+ order so that all the jump targets are known by the time we need to
+ reference them.
+ """
+
+ clauses = []
+ while tokens:
+ clauses.append(self._parse_clause(tokens))
+ if not tokens or tokens[0].type != 'OR':
+ break
+ tokens.pop(0)
+ else:
+ self._parser_state.error('empty argument expression')
+ return clauses
+
+ # default-action = 'kill-process'
+ # | 'kill-thread'
+ # | 'kill'
+ # | 'trap'
+ # ;
+ def _parse_default_action(self, tokens):
+ if not tokens:
+ self._parser_state.error('missing default action')
+ action_token = tokens.pop(0)
+ if action_token.type != 'ACTION':
+ return self._parser_state.error(
+ 'invalid default action', token=action_token)
+ if action_token.value == 'kill-process':
+ return bpf.KillProcess()
+ if action_token.value == 'kill-thread':
+ return bpf.KillThread()
+ if action_token.value == 'kill':
+ return self._kill_action
+ if action_token.value == 'trap':
+ return bpf.Trap()
+ return self._parser_state.error(
+ 'invalid permissive default action', token=action_token)
+
+ # action = 'allow' | '1'
+ # | 'kill-process'
+ # | 'kill-thread'
+ # | 'kill'
+ # | 'trap'
+ # | 'trace'
+ # | 'log'
+ # | 'return' , single-constant
+ # ;
+ def _parse_action(self, tokens):
+ if not tokens:
+ self._parser_state.error('missing action')
+ action_token = tokens.pop(0)
+ if action_token.type == 'ACTION':
+ if action_token.value == 'allow':
+ return bpf.Allow()
+ if action_token.value == 'kill':
+ return self._kill_action
+ if action_token.value == 'kill-process':
+ return bpf.KillProcess()
+ if action_token.value == 'kill-thread':
+ return bpf.KillThread()
+ if action_token.value == 'trap':
+ return bpf.Trap()
+ if action_token.value == 'trace':
+ return bpf.Trace()
+ if action_token.value == 'log':
+ return bpf.Log()
+ elif action_token.type == 'NUMERIC_CONSTANT':
+ constant = self._parse_single_constant(action_token)
+ if constant == 1:
+ return bpf.Allow()
+ elif action_token.type == 'RETURN':
+ if not tokens:
+ self._parser_state.error('missing return value')
+ return bpf.ReturnErrno(self._parse_single_constant(tokens.pop(0)))
+ return self._parser_state.error('invalid action', token=action_token)
+
+ # single-filter = action
+ # | argument-expression , [ ';' , action ]
+ # ;
+ def _parse_single_filter(self, tokens):
+ if not tokens:
+ self._parser_state.error('missing filter')
+ if tokens[0].type == 'ARGUMENT':
+ # Only argument expressions can start with an ARGUMENT token.
+ argument_expression = self.parse_argument_expression(tokens)
+ if tokens and tokens[0].type == 'SEMICOLON':
+ tokens.pop(0)
+ action = self._parse_action(tokens)
+ else:
+ action = bpf.Allow()
+ return Filter(argument_expression, action)
+ else:
+ return Filter(None, self._parse_action(tokens))
+
+ # filter = '{' , single-filter , [ { ',' , single-filter } ] , '}'
+ # | single-filter
+ # ;
+ def parse_filter(self, tokens):
+ """Parse a filter and return a list of Filter objects."""
+ if not tokens:
+ self._parser_state.error('missing filter')
+ filters = []
+ if tokens[0].type == 'LBRACE':
+ opening_brace = tokens.pop(0)
+ while tokens:
+ filters.append(self._parse_single_filter(tokens))
+ if not tokens or tokens[0].type != 'COMMA':
+ break
+ tokens.pop(0)
+ if not tokens or tokens[0].type != 'RBRACE':
+ self._parser_state.error('unclosed brace', token=opening_brace)
+ tokens.pop(0)
+ else:
+ filters.append(self._parse_single_filter(tokens))
+ return filters
+
+ # key-value-pair = identifier , '=', identifier , [ { ',' , identifier } ]
+ # ;
+ def _parse_key_value_pair(self, tokens):
+ if not tokens:
+ self._parser_state.error('missing key')
+ key = tokens.pop(0)
+ if key.type != 'IDENTIFIER':
+ self._parser_state.error('invalid key', token=key)
+ if not tokens:
+ self._parser_state.error('missing equal')
+ if tokens[0].type != 'EQUAL':
+ self._parser_state.error('invalid equal', token=tokens[0])
+ tokens.pop(0)
+ value_list = []
+ while tokens:
+ value = tokens.pop(0)
+ if value.type != 'IDENTIFIER':
+ self._parser_state.error('invalid value', token=value)
+ value_list.append(value.value)
+ if not tokens or tokens[0].type != 'COMMA':
+ break
+ tokens.pop(0)
+ else:
+ self._parser_state.error('empty value')
+ return (key.value, value_list)
+
+ # metadata = '[' , key-value-pair , [ { ';' , key-value-pair } ] , ']'
+ # ;
+ def _parse_metadata(self, tokens):
+ if not tokens:
+ self._parser_state.error('missing opening bracket')
+ opening_bracket = tokens.pop(0)
+ if opening_bracket.type != 'LBRACKET':
+ self._parser_state.error(
+ 'invalid opening bracket', token=opening_bracket)
+ metadata = {}
+ while tokens:
+ first_token = tokens[0]
+ key, value = self._parse_key_value_pair(tokens)
+ if key in metadata:
+ self._parser_state.error(
+ 'duplicate metadata key: "%s"' % key, token=first_token)
+ metadata[key] = value
+ if not tokens or tokens[0].type != 'SEMICOLON':
+ break
+ tokens.pop(0)
+ if not tokens or tokens[0].type != 'RBRACKET':
+ self._parser_state.error('unclosed bracket', token=opening_bracket)
+ tokens.pop(0)
+ return metadata
+
+ # syscall-descriptor = syscall-name , [ metadata ]
+ # | libc-function , [ metadata ]
+ # ;
+ def _parse_syscall_descriptor(self, tokens):
+ if not tokens:
+ self._parser_state.error('missing syscall descriptor')
+ syscall_descriptor = tokens.pop(0)
+ if syscall_descriptor.type != 'IDENTIFIER':
+ self._parser_state.error(
+ 'invalid syscall descriptor', token=syscall_descriptor)
+ # TODO(lhchavez): Support libc function names.
+ if tokens and tokens[0].type == 'LBRACKET':
+ metadata = self._parse_metadata(tokens)
+ if 'arch' in metadata and self._arch.arch_name not in metadata['arch']:
+ return ()
+ if syscall_descriptor.value not in self._arch.syscalls:
+ self._parser_state.error(
+ 'nonexistent syscall', token=syscall_descriptor)
+ return (Syscall(syscall_descriptor.value,
+ self._arch.syscalls[syscall_descriptor.value]), )
+
+ # filter-statement = '{' , syscall-descriptor , [ { ',', syscall-descriptor } ] , '}' ,
+ # ':' , filter
+ # | syscall-descriptor , ':' , filter
+ # ;
+ def parse_filter_statement(self, tokens):
+ """Parse a filter statement and return a ParsedFilterStatement."""
+ if not tokens:
+ self._parser_state.error('empty filter statement')
+ syscall_descriptors = []
+ if tokens[0].type == 'LBRACE':
+ opening_brace = tokens.pop(0)
+ while tokens:
+ syscall_descriptors.extend(
+ self._parse_syscall_descriptor(tokens))
+ if not tokens or tokens[0].type != 'COMMA':
+ break
+ tokens.pop(0)
+ if not tokens or tokens[0].type != 'RBRACE':
+ self._parser_state.error('unclosed brace', token=opening_brace)
+ tokens.pop(0)
+ else:
+ syscall_descriptors.extend(self._parse_syscall_descriptor(tokens))
+ if not tokens:
+ self._parser_state.error('missing colon')
+ if tokens[0].type != 'COLON':
+ self._parser_state.error('invalid colon', token=tokens[0])
+ tokens.pop(0)
+ parsed_filter = self.parse_filter(tokens)
+ if not syscall_descriptors:
+ return None
+ return ParsedFilterStatement(tuple(syscall_descriptors), parsed_filter)
+
+ # include-statement = '@include' , posix-path
+ # ;
+ def _parse_include_statement(self, tokens):
+ if not tokens:
+ self._parser_state.error('empty filter statement')
+ if tokens[0].type != 'INCLUDE':
+ self._parser_state.error('invalid include', token=tokens[0])
+ tokens.pop(0)
+ if not tokens:
+ self._parser_state.error('empty include path')
+ include_path = tokens.pop(0)
+ if include_path.type != 'PATH':
+ self._parser_state.error(
+ 'invalid include path', token=include_path)
+ if len(self._parser_states) == self._include_depth_limit:
+ self._parser_state.error('@include statement nested too deep')
+ include_filename = os.path.normpath(
+ os.path.join(
+ os.path.dirname(self._parser_state.filename),
+ include_path.value))
+ if not os.path.isfile(include_filename):
+ self._parser_state.error(
+ 'Could not @include %s' % include_filename, token=include_path)
+ return self._parse_policy_file(include_filename)
+
+ def _parse_frequency_file(self, filename):
+ self._parser_states.append(ParserState(filename))
+ try:
+ frequency_mapping = collections.defaultdict(int)
+ with open(filename) as frequency_file:
+ for line in frequency_file:
+ self._parser_state.set_line(line.rstrip())
+ tokens = self._parser_state.tokenize()
+
+ if not tokens:
+ continue
+
+ syscall_numbers = self._parse_syscall_descriptor(tokens)
+ if not tokens:
+ self._parser_state.error('missing colon')
+ if tokens[0].type != 'COLON':
+ self._parser_state.error(
+ 'invalid colon', token=tokens[0])
+ tokens.pop(0)
+
+ if not tokens:
+ self._parser_state.error('missing number')
+ number = tokens.pop(0)
+ if number.type != 'NUMERIC_CONSTANT':
+ self._parser_state.error(
+ 'invalid number', token=number)
+ number_value = int(number.value, base=0)
+ if number_value < 0:
+ self._parser_state.error(
+ 'invalid number', token=number)
+
+ for syscall_number in syscall_numbers:
+ frequency_mapping[syscall_number] += number_value
+ return frequency_mapping
+ finally:
+ self._parser_states.pop()
+
+ # frequency-statement = '@frequency' , posix-path
+ # ;
+ def _parse_frequency_statement(self, tokens):
+ if not tokens:
+ self._parser_state.error('empty frequency statement')
+ if tokens[0].type != 'FREQUENCY':
+ self._parser_state.error('invalid frequency', token=tokens[0])
+ tokens.pop(0)
+ if not tokens:
+ self._parser_state.error('empty frequency path')
+ frequency_path = tokens.pop(0)
+ if frequency_path.type != 'PATH':
+ self._parser_state.error(
+ 'invalid frequency path', token=frequency_path)
+ frequency_filename = os.path.normpath(
+ os.path.join(
+ os.path.dirname(self._parser_state.filename),
+ frequency_path.value))
+ if not os.path.isfile(frequency_filename):
+ self._parser_state.error(
+ 'Could not open frequency file %s' % frequency_filename,
+ token=frequency_path)
+ return self._parse_frequency_file(frequency_filename)
+
+ # default-statement = '@default' , default-action
+ # ;
+ def _parse_default_statement(self, tokens):
+ if not tokens:
+ self._parser_state.error('empty default statement')
+ if tokens[0].type != 'DEFAULT':
+ self._parser_state.error('invalid default', token=tokens[0])
+ tokens.pop(0)
+ if not tokens:
+ self._parser_state.error('empty action')
+ return self._parse_default_action(tokens)
+
+ def _parse_policy_file(self, filename):
+ self._parser_states.append(ParserState(filename))
+ try:
+ statements = []
+ with open(filename) as policy_file:
+ for line in policy_file:
+ self._parser_state.set_line(line.rstrip())
+ tokens = self._parser_state.tokenize()
+
+ if not tokens:
+ # Allow empty lines.
+ continue
+
+ if tokens[0].type == 'INCLUDE':
+ statements.extend(
+ self._parse_include_statement(tokens))
+ elif tokens[0].type == 'FREQUENCY':
+ for syscall_number, frequency in self._parse_frequency_statement(
+ tokens).items():
+ self._frequency_mapping[
+ syscall_number] += frequency
+ elif tokens[0].type == 'DEFAULT':
+ self._default_action = self._parse_default_statement(
+ tokens)
+ else:
+ statement = self.parse_filter_statement(tokens)
+ if statement is None:
+ # If all the syscalls in the statement are for
+ # another arch, skip the whole statement.
+ continue
+ statements.append(statement)
+
+ if tokens:
+ self._parser_state.error(
+ 'extra tokens', token=tokens[0])
+ return statements
+ finally:
+ self._parser_states.pop()
+
+ def parse_file(self, filename):
+ """Parse a file and return the list of FilterStatements."""
+ self._frequency_mapping = collections.defaultdict(int)
+ try:
+ statements = [x for x in self._parse_policy_file(filename)]
+ except RecursionError:
+ raise ParseException('recursion limit exceeded', filename,
+ self._parser_states[-1].line)
+
+ # Collapse statements into a single syscall-to-filter-list.
+ syscall_filter_mapping = {}
+ filter_statements = []
+ for syscalls, filters in statements:
+ for syscall in syscalls:
+ if syscall not in syscall_filter_mapping:
+ filter_statements.append(
+ FilterStatement(
+ syscall, self._frequency_mapping.get(syscall, 1),
+ []))
+ syscall_filter_mapping[syscall] = filter_statements[-1]
+ syscall_filter_mapping[syscall].filters.extend(filters)
+ for filter_statement in filter_statements:
+ unconditional_actions_suffix = list(
+ itertools.dropwhile(lambda filt: filt.expression is not None,
+ filter_statement.filters))
+ if len(unconditional_actions_suffix) == 1:
+ # The last filter already has an unconditional action, no need
+ # to add another one.
+ continue
+ if len(unconditional_actions_suffix) > 1:
+ raise ParseException(('Syscall %s (number %d) already had '
+ 'an unconditional action applied') %
+ (filter_statement.syscall.name,
+ filter_statement.syscall.number),
+ filename, self._parser_states[-1].line)
+ assert not unconditional_actions_suffix
+ filter_statement.filters.append(
+ Filter(expression=None, action=self._default_action))
+ return ParsedPolicy(self._default_action, filter_statements)
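
The grammar comments above add @default, @frequency, @include, grouped syscalls, and per-syscall metadata to the policy language. A hypothetical policy pulling the new constructs together (file names are placeholders, not part of the patch):

@default kill-thread
@frequency ./example.frequency
@include ./common-allowed.policy

# Grouped syscalls share one filter list; metadata restricts a line to an arch.
{read, write}: arg0 == 0 && arg2 in 0xffff || arg0 == 1; return ENOSYS
open[arch=test]: allow
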
diff --git a/tools/parser_unittest.py b/tools/parser_unittest.py
index d40ab42..4fba590 100755
--- a/tools/parser_unittest.py
+++ b/tools/parser_unittest.py
@@ -21,9 +21,12 @@ from __future__ import division
from __future__ import print_function
import os
+import shutil
+import tempfile
import unittest
import arch
+import bpf
import parser # pylint: disable=wrong-import-order
ARCH_64 = arch.Arch.load_from_json(
@@ -97,7 +100,8 @@ class ParseConstantTests(unittest.TestCase):
def setUp(self):
self.arch = ARCH_64
- self.parser = parser.PolicyParser(self.arch)
+ self.parser = parser.PolicyParser(
+ self.arch, kill_action=bpf.KillProcess())
def _tokenize(self, line):
# pylint: disable=protected-access
@@ -233,5 +237,401 @@ class ParseConstantTests(unittest.TestCase):
self.parser.parse_value(self._tokenize('0|'))
+class ParseFilterExpressionTests(unittest.TestCase):
+ """Tests for PolicyParser.parse_argument_expression."""
+
+ def setUp(self):
+ self.arch = ARCH_64
+ self.parser = parser.PolicyParser(
+ self.arch, kill_action=bpf.KillProcess())
+
+ def _tokenize(self, line):
+ # pylint: disable=protected-access
+ self.parser._parser_state.set_line(line)
+ return self.parser._parser_state.tokenize()
+
+ def test_parse_argument_expression(self):
+ """Accept valid argument expressions."""
+ self.assertEqual(
+ self.parser.parse_argument_expression(
+ self._tokenize(
+ 'arg0 in 0xffff || arg0 == PROT_EXEC && arg1 == PROT_WRITE'
+ )), [
+ [parser.Atom(0, 'in', 0xffff)],
+ [parser.Atom(0, '==', 4),
+ parser.Atom(1, '==', 2)],
+ ])
+
+ def test_parse_empty_argument_expression(self):
+ """Reject empty argument expressions."""
+ with self.assertRaisesRegex(parser.ParseException,
+ 'empty argument expression'):
+ self.parser.parse_argument_expression(
+ self._tokenize('arg0 in 0xffff ||'))
+
+ def test_parse_empty_clause(self):
+ """Reject empty clause."""
+ with self.assertRaisesRegex(parser.ParseException, 'empty clause'):
+ self.parser.parse_argument_expression(
+ self._tokenize('arg0 in 0xffff &&'))
+
+ def test_parse_invalid_argument(self):
+ """Reject invalid argument."""
+ with self.assertRaisesRegex(parser.ParseException, 'invalid argument'):
+ self.parser.parse_argument_expression(
+ self._tokenize('argX in 0xffff'))
+
+ def test_parse_invalid_operator(self):
+ """Reject invalid operator."""
+ with self.assertRaisesRegex(parser.ParseException, 'invalid operator'):
+ self.parser.parse_argument_expression(
+ self._tokenize('arg0 = 0xffff'))
+
+
+class ParseFilterTests(unittest.TestCase):
+ """Tests for PolicyParser.parse_filter."""
+
+ def setUp(self):
+ self.arch = ARCH_64
+ self.parser = parser.PolicyParser(
+ self.arch, kill_action=bpf.KillProcess())
+
+ def _tokenize(self, line):
+ # pylint: disable=protected-access
+ self.parser._parser_state.set_line(line)
+ return self.parser._parser_state.tokenize()
+
+ def test_parse_filter(self):
+ """Accept valid filters."""
+ self.assertEqual(
+ self.parser.parse_filter(self._tokenize('arg0 == 0')), [
+ parser.Filter([[parser.Atom(0, '==', 0)]], bpf.Allow()),
+ ])
+ self.assertEqual(
+ self.parser.parse_filter(self._tokenize('kill-process')), [
+ parser.Filter(None, bpf.KillProcess()),
+ ])
+ self.assertEqual(
+ self.parser.parse_filter(self._tokenize('kill-thread')), [
+ parser.Filter(None, bpf.KillThread()),
+ ])
+ self.assertEqual(
+ self.parser.parse_filter(self._tokenize('trap')), [
+ parser.Filter(None, bpf.Trap()),
+ ])
+ self.assertEqual(
+ self.parser.parse_filter(self._tokenize('return ENOSYS')), [
+ parser.Filter(None,
+ bpf.ReturnErrno(self.arch.constants['ENOSYS'])),
+ ])
+ self.assertEqual(
+ self.parser.parse_filter(self._tokenize('trace')), [
+ parser.Filter(None, bpf.Trace()),
+ ])
+ self.assertEqual(
+ self.parser.parse_filter(self._tokenize('log')), [
+ parser.Filter(None, bpf.Log()),
+ ])
+ self.assertEqual(
+ self.parser.parse_filter(self._tokenize('allow')), [
+ parser.Filter(None, bpf.Allow()),
+ ])
+ self.assertEqual(
+ self.parser.parse_filter(self._tokenize('1')), [
+ parser.Filter(None, bpf.Allow()),
+ ])
+ self.assertEqual(
+ self.parser.parse_filter(
+ self._tokenize(
+ '{ arg0 == 0, arg0 == 1; return ENOSYS, trap }')),
+ [
+ parser.Filter([[parser.Atom(0, '==', 0)]], bpf.Allow()),
+ parser.Filter([[parser.Atom(0, '==', 1)]],
+ bpf.ReturnErrno(self.arch.constants['ENOSYS'])),
+ parser.Filter(None, bpf.Trap()),
+ ])
+
+ def test_parse_missing_return_value(self):
+ """Reject missing return value."""
+ with self.assertRaisesRegex(parser.ParseException,
+ 'missing return value'):
+ self.parser.parse_filter(self._tokenize('return'))
+
+ def test_parse_invalid_return_value(self):
+ """Reject invalid return value."""
+ with self.assertRaisesRegex(parser.ParseException, 'invalid constant'):
+ self.parser.parse_filter(self._tokenize('return arg0'))
+
+ def test_parse_unclosed_brace(self):
+ """Reject unclosed brace."""
+ with self.assertRaisesRegex(parser.ParseException, 'unclosed brace'):
+ self.parser.parse_filter(self._tokenize('{ allow'))
+
+
+class ParseFilterStatementTests(unittest.TestCase):
+ """Tests for PolicyParser.parse_filter_statement."""
+
+ def setUp(self):
+ self.arch = ARCH_64
+ self.parser = parser.PolicyParser(
+ self.arch, kill_action=bpf.KillProcess())
+
+ def _tokenize(self, line):
+ # pylint: disable=protected-access
+ self.parser._parser_state.set_line(line)
+ return self.parser._parser_state.tokenize()
+
+ def test_parse_filter_statement(self):
+ """Accept valid filter statements."""
+ self.assertEqual(
+ self.parser.parse_filter_statement(
+ self._tokenize('read: arg0 == 0')),
+ parser.ParsedFilterStatement((parser.Syscall('read', 0), ), [
+ parser.Filter([[parser.Atom(0, '==', 0)]], bpf.Allow()),
+ ]))
+ self.assertEqual(
+ self.parser.parse_filter_statement(
+ self._tokenize('{read, write}: arg0 == 0')),
+ parser.ParsedFilterStatement((
+ parser.Syscall('read', 0),
+ parser.Syscall('write', 1),
+ ), [
+ parser.Filter([[parser.Atom(0, '==', 0)]], bpf.Allow()),
+ ]))
+
+ def test_parse_metadata(self):
+ """Accept valid filter statements with metadata."""
+ self.assertEqual(
+ self.parser.parse_filter_statement(
+ self._tokenize('read[arch=test]: arg0 == 0')),
+ parser.ParsedFilterStatement((parser.Syscall('read', 0), ), [
+ parser.Filter([[parser.Atom(0, '==', 0)]], bpf.Allow()),
+ ]))
+ self.assertEqual(
+ self.parser.parse_filter_statement(
+ self._tokenize(
+ '{read, nonexistent[arch=nonexistent]}: arg0 == 0')),
+ parser.ParsedFilterStatement((parser.Syscall('read', 0), ), [
+ parser.Filter([[parser.Atom(0, '==', 0)]], bpf.Allow()),
+ ]))
+
+ def test_parse_unclosed_brace(self):
+ """Reject unclosed brace."""
+ with self.assertRaisesRegex(parser.ParseException, 'unclosed brace'):
+ self.parser.parse_filter_statement(
+ self._tokenize('{ read, write: arg0 == 0'))
+
+ def test_parse_missing_colon(self):
+ """Reject missing colon."""
+ with self.assertRaisesRegex(parser.ParseException, 'missing colon'):
+ self.parser.parse_filter_statement(self._tokenize('read'))
+
+ def test_parse_invalid_colon(self):
+ """Reject invalid colon."""
+ with self.assertRaisesRegex(parser.ParseException, 'invalid colon'):
+ self.parser.parse_filter_statement(self._tokenize('read arg0'))
+
+ def test_parse_missing_filter(self):
+ """Reject missing filter."""
+ with self.assertRaisesRegex(parser.ParseException, 'missing filter'):
+ self.parser.parse_filter_statement(self._tokenize('read:'))
+
+
+class ParseFileTests(unittest.TestCase):
+ """Tests for PolicyParser.parse_file."""
+
+ def setUp(self):
+ self.arch = ARCH_64
+ self.parser = parser.PolicyParser(
+ self.arch, kill_action=bpf.KillProcess())
+ self.tempdir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def _write_file(self, filename, contents):
+ """Helper to write out a file for testing."""
+ path = os.path.join(self.tempdir, filename)
+ with open(path, 'w') as outf:
+ outf.write(contents)
+ return path
+
+ def test_parse_simple(self):
+ """Allow simple policy files."""
+ path = self._write_file(
+ 'test.policy', """
+ # Comment.
+ read: allow
+ write: allow
+ """)
+
+ self.assertEqual(
+ self.parser.parse_file(path),
+ parser.ParsedPolicy(
+ default_action=bpf.KillProcess(),
+ filter_statements=[
+ parser.FilterStatement(
+ syscall=parser.Syscall('read', 0),
+ frequency=1,
+ filters=[
+ parser.Filter(None, bpf.Allow()),
+ ]),
+ parser.FilterStatement(
+ syscall=parser.Syscall('write', 1),
+ frequency=1,
+ filters=[
+ parser.Filter(None, bpf.Allow()),
+ ]),
+ ]))
+
+ def test_parse_default(self):
+ """Allow defining a default action."""
+ path = self._write_file(
+ 'test.policy', """
+ @default kill-thread
+ read: allow
+ """)
+
+ self.assertEqual(
+ self.parser.parse_file(path),
+ parser.ParsedPolicy(
+ default_action=bpf.KillThread(),
+ filter_statements=[
+ parser.FilterStatement(
+ syscall=parser.Syscall('read', 0),
+ frequency=1,
+ filters=[
+ parser.Filter(None, bpf.Allow()),
+ ]),
+ ]))
+
+ def test_parse_default_permissive(self):
+ """Reject defining a permissive default action."""
+ path = self._write_file(
+ 'test.policy', """
+ @default log
+ read: allow
+ """)
+
+ with self.assertRaisesRegex(parser.ParseException,
+ r'invalid permissive default action'):
+ self.parser.parse_file(path)
+
+ def test_parse_simple_grouped(self):
+ """Allow simple policy files."""
+ path = self._write_file(
+ 'test.policy', """
+ # Comment.
+ {read, write}: allow
+ """)
+
+ self.assertEqual(
+ self.parser.parse_file(path),
+ parser.ParsedPolicy(
+ default_action=bpf.KillProcess(),
+ filter_statements=[
+ parser.FilterStatement(
+ syscall=parser.Syscall('read', 0),
+ frequency=1,
+ filters=[
+ parser.Filter(None, bpf.Allow()),
+ ]),
+ parser.FilterStatement(
+ syscall=parser.Syscall('write', 1),
+ frequency=1,
+ filters=[
+ parser.Filter(None, bpf.Allow()),
+ ]),
+ ]))
+
+ def test_parse_include(self):
+ """Allow including policy files."""
+ path = self._write_file(
+ 'test.include.policy', """
+ {read, write}: arg0 == 0; allow
+ """)
+ path = self._write_file(
+ 'test.policy', """
+ @include ./test.include.policy
+ read: return ENOSYS
+ """)
+
+ self.assertEqual(
+ self.parser.parse_file(path),
+ parser.ParsedPolicy(
+ default_action=bpf.KillProcess(),
+ filter_statements=[
+ parser.FilterStatement(
+ syscall=parser.Syscall('read', 0),
+ frequency=1,
+ filters=[
+ parser.Filter([[parser.Atom(0, '==', 0)]],
+ bpf.Allow()),
+ parser.Filter(
+ None,
+ bpf.ReturnErrno(
+ self.arch.constants['ENOSYS'])),
+ ]),
+ parser.FilterStatement(
+ syscall=parser.Syscall('write', 1),
+ frequency=1,
+ filters=[
+ parser.Filter([[parser.Atom(0, '==', 0)]],
+ bpf.Allow()),
+ parser.Filter(None, bpf.KillProcess()),
+ ]),
+ ]))
+
+ def test_parse_frequency(self):
+ """Allow including frequency files."""
+ self._write_file(
+ 'test.frequency', """
+ read: 2
+ write: 3
+ """)
+ path = self._write_file(
+ 'test.policy', """
+ @frequency ./test.frequency
+ read: allow
+ """)
+
+ self.assertEqual(
+ self.parser.parse_file(path),
+ parser.ParsedPolicy(
+ default_action=bpf.KillProcess(),
+ filter_statements=[
+ parser.FilterStatement(
+ syscall=parser.Syscall('read', 0),
+ frequency=2,
+ filters=[
+ parser.Filter(None, bpf.Allow()),
+ ]),
+ ]))
+
+ def test_parse_multiple_unconditional(self):
+ """Reject actions after an unconditional action."""
+ path = self._write_file(
+ 'test.policy', """
+ read: allow
+ read: allow
+ """)
+
+ with self.assertRaisesRegex(
+ parser.ParseException,
+ r'Syscall read.*already had an unconditional action applied'):
+ self.parser.parse_file(path)
+
+ path = self._write_file(
+ 'test.policy', """
+ read: log
+ read: arg0 == 0; log
+ """)
+
+ with self.assertRaisesRegex(
+ parser.ParseException,
+ r'Syscall read.*already had an unconditional action applied'):
+ self.parser.parse_file(path)
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/tools/testdata/arch_64.json b/tools/testdata/arch_64.json
index c23f988..bd3e2f4 100644
--- a/tools/testdata/arch_64.json
+++ b/tools/testdata/arch_64.json
@@ -4,9 +4,12 @@
"bits": 64,
"syscalls": {
"read": 0,
- "write": 1
+ "write": 1,
+ "open": 2,
+ "close": 3
},
"constants": {
+ "ENOSYS": 38,
"O_RDONLY": 0,
"PROT_WRITE": 2,
"PROT_EXEC": 4