From b0718adc6d4b365527cb835a2200f5a4515fe941 Mon Sep 17 00:00:00 2001 From: Javi Merino Date: Thu, 14 Jan 2016 18:38:20 +0000 Subject: tests: add a basic assertStatement() check --- tests/test_common_utils.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_common_utils.py b/tests/test_common_utils.py index 56398be..9226ec3 100644 --- a/tests/test_common_utils.py +++ b/tests/test_common_utils.py @@ -14,8 +14,10 @@ # from bart.common import Utils +from bart.common.Analyzer import Analyzer import unittest import pandas as pd +import trappy class TestCommonUtils(unittest.TestCase): @@ -96,3 +98,18 @@ class TestCommonUtils(unittest.TestCase): method="rect", step="pre"), 0) + + +class TestAnalyzer(unittest.TestCase): + + def test_assert_statement_bool(self): + """Check that asssertStatement() works with a simple boolean case""" + + rolls_dfr = pd.DataFrame({"results": [1, 3, 2, 6, 2, 4]}) + trace = trappy.BareTrace() + trace.add_parsed_event("dice_rolls", rolls_dfr) + config = {"MAX_DICE_NUMBER": 6} + + t = Analyzer(trace, config) + statement = "numpy.max(dice_rolls:results) <= MAX_DICE_NUMBER" + self.assertTrue(t.assertStatement(statement, select=0)) -- cgit v1.2.3 From 5af9d234eb3445a36c34695c5d1b1bd8b88d5c6f Mon Sep 17 00:00:00 2001 From: Javi Merino Date: Thu, 14 Jan 2016 18:57:24 +0000 Subject: Analyzer: assert when the parsed statement returns a dataframe of bools Sometimes it's useful to assert things like: "event:column == 3". Teach assertStatement() handle the case where the result of parsing a statement is a dataframe of bools. --- bart/common/Analyzer.py | 11 ++++++----- tests/test_common_utils.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/bart/common/Analyzer.py b/bart/common/Analyzer.py index 51194d7..d9dc74d 100644 --- a/bart/common/Analyzer.py +++ b/bart/common/Analyzer.py @@ -22,6 +22,7 @@ implemented yet. from trappy.stats.grammar import Parser import warnings import numpy as np +import pandas as pd # pylint: disable=invalid-name @@ -56,12 +57,12 @@ class Analyzer(object): result = self.getStatement(statement, select=select) - # pylint: disable=no-member - if not (isinstance(result, bool) or isinstance(result, np.bool_)): - warnings.warn( - "solution of {} is not an instance of bool".format(statement)) + if isinstance(result, pd.DataFrame): + result = result.all().all() + elif not(isinstance(result, bool) or isinstance(result, np.bool_)): # pylint: disable=no-member + warnings.warn("solution of {} is not boolean".format(statement)) + return result - # pylint: enable=no-member def getStatement(self, statement, reference=False, select=None): """Evaluate the statement""" diff --git a/tests/test_common_utils.py b/tests/test_common_utils.py index 9226ec3..09b31e3 100644 --- a/tests/test_common_utils.py +++ b/tests/test_common_utils.py @@ -113,3 +113,18 @@ class TestAnalyzer(unittest.TestCase): t = Analyzer(trace, config) statement = "numpy.max(dice_rolls:results) <= MAX_DICE_NUMBER" self.assertTrue(t.assertStatement(statement, select=0)) + + def test_assert_statement_dataframe(self): + """assertStatement() works if the generated statement creates a pandas.DataFrame of bools""" + + rolls_dfr = pd.DataFrame({"results": [1, 3, 2, 6, 2, 4]}) + trace = trappy.BareTrace() + trace.add_parsed_event("dice_rolls", rolls_dfr) + config = {"MIN_DICE_NUMBER": 1, "MAX_DICE_NUMBER": 6} + t = Analyzer(trace, config) + + statement = "(dice_rolls:results <= MAX_DICE_NUMBER) & (dice_rolls:results >= MIN_DICE_NUMBER)" + self.assertTrue(t.assertStatement(statement)) + + statement = "dice_rolls:results == 3" + self.assertFalse(t.assertStatement(statement)) -- cgit v1.2.3