aboutsummaryrefslogtreecommitdiff
path: root/gd/cert/truth.py
blob: eba6be343ced237a10eaf4990909e8b9acb53a6c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#!/usr/bin/env python3
#
#   Copyright 2020 - The Android Open Source Project
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

from datetime import timedelta

from mobly.asserts import assert_true
from mobly.asserts import assert_false
from mobly import signals

from cert.event_stream import IEventStream
from cert.event_stream import NOT_FOR_YOU_assert_event_occurs
from cert.event_stream import NOT_FOR_YOU_assert_all_events_occur
from cert.event_stream import NOT_FOR_YOU_assert_none_matching
from cert.event_stream import NOT_FOR_YOU_assert_none


class ObjectSubject(object):
    """Fluent assertions on an arbitrary value.

    Each check returns ``None`` when the expectation holds and raises
    mobly's ``signals.TestFailure`` when it does not.
    """

    def __init__(self, value):
        # The value under test; subclasses read it as ``self._value``.
        self._value = value

    def isEqualTo(self, other):
        """Fail unless the wrapped value compares equal (``==``) to ``other``.

        Raises:
            signals.TestFailure: if ``self._value != other``.
        """
        if self._value != other:
            raise signals.TestFailure("Expected \"%s\" to be equal to \"%s\"" % (self._value, other), extras=None)

    def isNotEqualTo(self, other):
        """Fail if the wrapped value compares equal (``==``) to ``other``.

        Raises:
            signals.TestFailure: if ``self._value == other``.
        """
        if self._value == other:
            raise signals.TestFailure("Expected \"%s\" to not be equal to \"%s\"" % (self._value, other), extras=None)

    def isNone(self):
        """Fail unless the wrapped value is ``None``.

        Raises:
            signals.TestFailure: if the value is not ``None``.
        """
        if self._value is not None:
            # Wrap in a 1-tuple: a bare tuple value would otherwise be taken
            # as the %-format argument list and raise TypeError (or, for a
            # 1-tuple, print only its element) instead of a TestFailure.
            raise signals.TestFailure("Expected \"%s\" to be None" % (self._value,), extras=None)

    def isNotNone(self):
        """Fail if the wrapped value is ``None``.

        Raises:
            signals.TestFailure: if the value is ``None``.
        """
        if self._value is None:
            raise signals.TestFailure("Expected \"%s\" to not be None" % (self._value,), extras=None)


# Default ``timeout`` for the event-stream assertions in this module (3 s).
DEFAULT_TIMEOUT = timedelta(milliseconds=3000)


class EventStreamSubject(ObjectSubject):
    """Assertion entry point for an event stream.

    Exposes ``emits`` / ``emitsNone``; the immediate checks hand back an
    ``EventStreamContinuationSubject`` so further expectations can be
    chained on the same stream.
    """

    def emits(self, *match_fns, at_least_times=1, timeout=DEFAULT_TIMEOUT):
        """Expect the stream to produce events accepted by ``match_fns``.

        With exactly one matcher the check is delegated immediately to
        ``NOT_FOR_YOU_assert_event_occurs`` and an
        ``EventStreamContinuationSubject`` is returned. With several matchers
        nothing is checked yet: a ``MultiMatchStreamSubject`` is returned and
        the caller must choose ``inOrder()`` or ``inAnyOrder()``.
        ``at_least_times`` is only forwarded in the single-matcher case.

        Raises:
            signals.TestFailure: if no matcher is supplied.
        """
        matcher_count = len(match_fns)
        if not matcher_count:
            raise signals.TestFailure("Must specify a match function")
        if matcher_count > 1:
            return MultiMatchStreamSubject(self._value, match_fns, timeout)
        NOT_FOR_YOU_assert_event_occurs(self._value, match_fns[0], at_least_times=at_least_times, timeout=timeout)
        return EventStreamContinuationSubject(self._value)

    def emitsNone(self, *match_fns, timeout=DEFAULT_TIMEOUT):
        """Expect the stream to emit nothing (or nothing matching) in ``timeout``.

        With no matcher the check is delegated to ``NOT_FOR_YOU_assert_none``;
        with one matcher to ``NOT_FOR_YOU_assert_none_matching``.

        Raises:
            signals.TestFailure: if more than one matcher is supplied.
        """
        if len(match_fns) > 1:
            raise signals.TestFailure("Cannot specify multiple match functions")
        if match_fns:
            NOT_FOR_YOU_assert_none_matching(self._value, match_fns[0], timeout=timeout)
        else:
            NOT_FOR_YOU_assert_none(self._value, timeout=timeout)
        return EventStreamContinuationSubject(self._value)


class MultiMatchStreamSubject(object):
    """Deferred check that a stream produces events for several matchers.

    Returned by ``emits``/``then`` when more than one match function is
    given; the assertion only runs once ``inOrder`` or ``inAnyOrder`` is
    called.
    """

    def __init__(self, stream, match_fns, timeout):
        self._stream, self._match_fns, self._timeout = stream, match_fns, timeout

    def _check(self, order_matters):
        # Shared body of inOrder/inAnyOrder; only the ordering flag differs.
        NOT_FOR_YOU_assert_all_events_occur(self._stream, self._match_fns, order_matters=order_matters, timeout=self._timeout)
        return EventStreamContinuationSubject(self._stream)

    def inAnyOrder(self):
        """Run the check with ``order_matters=False``; return a continuation."""
        return self._check(order_matters=False)

    def inOrder(self):
        """Run the check with ``order_matters=True``; return a continuation."""
        return self._check(order_matters=True)


class EventStreamContinuationSubject(ObjectSubject):
    """Follow-up assertions on a stream after an earlier expectation passed.

    Lets expectations be chained, e.g.
    ``assertThat(stream).emits(a).then(b).thenNone()``.
    """

    def then(self, *match_fns, at_least_times=1, timeout=DEFAULT_TIMEOUT):
        """Expect further events accepted by ``match_fns``.

        One matcher: checked immediately via
        ``NOT_FOR_YOU_assert_event_occurs``, returning a new continuation.
        Several matchers: a ``MultiMatchStreamSubject`` is returned and the
        caller picks ``inOrder()``/``inAnyOrder()``; ``at_least_times`` is
        not forwarded in that case.

        Raises:
            signals.TestFailure: if no matcher is supplied.
        """
        if not match_fns:
            raise signals.TestFailure("Must specify a match function")
        if len(match_fns) != 1:
            return MultiMatchStreamSubject(self._value, match_fns, timeout)
        only_fn = match_fns[0]
        NOT_FOR_YOU_assert_event_occurs(self._value, only_fn, at_least_times=at_least_times, timeout=timeout)
        return EventStreamContinuationSubject(self._value)

    def thenNone(self, *match_fns, timeout=DEFAULT_TIMEOUT):
        """Expect no further events (or none matching) within ``timeout``.

        No matcher delegates to ``NOT_FOR_YOU_assert_none``; one matcher
        delegates to ``NOT_FOR_YOU_assert_none_matching``.

        Raises:
            signals.TestFailure: if more than one matcher is supplied.
        """
        if len(match_fns) > 1:
            raise signals.TestFailure("Cannot specify multiple match functions")
        if not match_fns:
            NOT_FOR_YOU_assert_none(self._value, timeout=timeout)
        else:
            NOT_FOR_YOU_assert_none_matching(self._value, match_fns[0], timeout=timeout)
        return EventStreamContinuationSubject(self._value)


class BooleanSubject(ObjectSubject):
    """Assertions on a boolean value."""

    def __init__(self, value):
        super().__init__(value)

    def isTrue(self):
        """Fail (via mobly ``assert_true``) unless the wrapped value is truthy."""
        # Descriptive message instead of "" so a failure says what was checked.
        assert_true(self._value, "Expected \"%s\" to be True" % (self._value,))

    def isFalse(self):
        """Fail (via mobly ``assert_false``) unless the wrapped value is falsy."""
        assert_false(self._value, "Expected \"%s\" to be False" % (self._value,))


class TimeDeltaSubject(ObjectSubject):
    """Assertions on a ``datetime.timedelta`` value."""

    def __init__(self, value):
        super().__init__(value)

    def isWithin(self, time_bound):
        """Fail unless the wrapped duration is strictly less than ``time_bound``."""
        # Descriptive message instead of "" so a failure reports both values.
        assert_true(self._value < time_bound, "Expected \"%s\" to be within \"%s\"" % (self._value, time_bound))


def assertThat(subject):
    """Wrap ``subject`` in the assertion subject that matches its type.

    Exact ``bool`` -> BooleanSubject; ``IEventStream`` -> EventStreamSubject;
    ``timedelta`` -> TimeDeltaSubject; anything else -> ObjectSubject.
    """
    # Exact type check (not isinstance) for bools, as in the original.
    if type(subject) is bool:
        return BooleanSubject(subject)
    # Checked in order; the first matching base type wins.
    for base_type, subject_cls in ((IEventStream, EventStreamSubject), (timedelta, TimeDeltaSubject)):
        if isinstance(subject, base_type):
            return subject_cls(subject)
    return ObjectSubject(subject)