diff options
Diffstat (limited to 'tests/test_common_utils.py')
-rw-r--r-- | tests/test_common_utils.py | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/tests/test_common_utils.py b/tests/test_common_utils.py index 56398be..09b31e3 100644 --- a/tests/test_common_utils.py +++ b/tests/test_common_utils.py @@ -14,8 +14,10 @@ # from bart.common import Utils +from bart.common.Analyzer import Analyzer import unittest import pandas as pd +import trappy class TestCommonUtils(unittest.TestCase): @@ -96,3 +98,33 @@ class TestCommonUtils(unittest.TestCase): method="rect", step="pre"), 0) + + +class TestAnalyzer(unittest.TestCase): + + def test_assert_statement_bool(self): + """Check that asssertStatement() works with a simple boolean case""" + + rolls_dfr = pd.DataFrame({"results": [1, 3, 2, 6, 2, 4]}) + trace = trappy.BareTrace() + trace.add_parsed_event("dice_rolls", rolls_dfr) + config = {"MAX_DICE_NUMBER": 6} + + t = Analyzer(trace, config) + statement = "numpy.max(dice_rolls:results) <= MAX_DICE_NUMBER" + self.assertTrue(t.assertStatement(statement, select=0)) + + def test_assert_statement_dataframe(self): + """assertStatement() works if the generated statement creates a pandas.DataFrame of bools""" + + rolls_dfr = pd.DataFrame({"results": [1, 3, 2, 6, 2, 4]}) + trace = trappy.BareTrace() + trace.add_parsed_event("dice_rolls", rolls_dfr) + config = {"MIN_DICE_NUMBER": 1, "MAX_DICE_NUMBER": 6} + t = Analyzer(trace, config) + + statement = "(dice_rolls:results <= MAX_DICE_NUMBER) & (dice_rolls:results >= MIN_DICE_NUMBER)" + self.assertTrue(t.assertStatement(statement)) + + statement = "dice_rolls:results == 3" + self.assertFalse(t.assertStatement(statement)) |