Diffstat (limited to 'pw_tokenizer/py')
-rw-r--r--  pw_tokenizer/py/BUILD.bazel                          |   7
-rwxr-xr-x  pw_tokenizer/py/detokenize_test.py                   | 345
-rwxr-xr-x  pw_tokenizer/py/generate_hash_test_data.py           | 112
-rwxr-xr-x  pw_tokenizer/py/pw_tokenizer/database.py             |   3
-rwxr-xr-x  pw_tokenizer/py/pw_tokenizer/detokenize.py           | 488
-rw-r--r--  pw_tokenizer/py/pw_tokenizer/encode.py               |   3
-rw-r--r--  pw_tokenizer/py/pw_tokenizer/proto/__init__.py       |   4
-rw-r--r--  pw_tokenizer/py/pw_tokenizer/serial_detokenizer.py   |   2
-rw-r--r--  pw_tokenizer/py/pw_tokenizer/tokens.py               |  12
-rw-r--r--  pw_tokenizer/py/setup.py                             |  22
10 files changed, 766 insertions, 232 deletions
diff --git a/pw_tokenizer/py/BUILD.bazel b/pw_tokenizer/py/BUILD.bazel index 540495a66..299d29ce4 100644 --- a/pw_tokenizer/py/BUILD.bazel +++ b/pw_tokenizer/py/BUILD.bazel @@ -30,6 +30,7 @@ py_library( "pw_tokenizer/serial_detokenizer.py", "pw_tokenizer/tokens.py", ], + imports = ["."], deps = [ "//pw_cli/py:pw_cli", ], @@ -46,7 +47,7 @@ py_binary( # This test attempts to directly access files in the source tree, which is # incompatible with sandboxing. -# TODO(b/241307309): Fix this test. +# TODO: b/241307309 - Fix this test. filegroup( name = "database_test", srcs = ["database_test.py"], @@ -83,7 +84,7 @@ proto_library( ], ) -# TODO(b/241456982): This target can't be built due to limitations of +# TODO: b/241456982 - This target can't be built due to limitations of # py_proto_library. # py_proto_library( # name = "detokenize_proto_test_pb2", @@ -101,7 +102,6 @@ py_test( ], deps = [ ":pw_tokenizer", - "@rules_python//python/runfiles", ], ) @@ -113,7 +113,6 @@ py_test( ], deps = [ ":pw_tokenizer", - "@rules_python//python/runfiles", ], ) diff --git a/pw_tokenizer/py/detokenize_test.py b/pw_tokenizer/py/detokenize_test.py index df710c7e9..36bb1fa6e 100755 --- a/pw_tokenizer/py/detokenize_test.py +++ b/pw_tokenizer/py/detokenize_test.py @@ -15,12 +15,15 @@ """Tests for detokenize.""" import base64 +import concurrent import datetime as dt +import functools import io import os from pathlib import Path import struct import tempfile +from typing import Any, Callable, NamedTuple, Tuple import unittest from unittest import mock @@ -451,6 +454,35 @@ class DetokenizeWithCollisions(unittest.TestCase): self.assertIn('#0 -1', repr(unambiguous)) +class ManualPoolExecutor(concurrent.futures.Executor): + """A stubbed pool executor that captures the most recent work request + and holds it until the public process method is manually called.""" + + def __init__(self): + super().__init__() + self._func = None + + # pylint: disable=arguments-differ + def submit(self, func, *args, **kwargs): + """Submits work to the pool, stashing the partial for later use.""" + self._func = functools.partial(func, *args, **kwargs) + + def process(self): + """Processes the latest func submitted to the pool.""" + if self._func is not None: + self._func() + self._func = None + + +class InlinePoolExecutor(concurrent.futures.Executor): + """A stubbed pool executor that runs work immediately, inline.""" + + # pylint: disable=arguments-differ + def submit(self, func, *args, **kwargs): + """Submits work to the pool, stashing the partial for later use.""" + func(*args, **kwargs) + + @mock.patch('os.path.getmtime') class AutoUpdatingDetokenizerTest(unittest.TestCase): """Tests the AutoUpdatingDetokenizer class.""" @@ -478,18 +510,79 @@ class AutoUpdatingDetokenizerTest(unittest.TestCase): try: file.close() + pool = ManualPoolExecutor() detok = detokenize.AutoUpdatingDetokenizer( - file.name, min_poll_period_s=0 + file.name, min_poll_period_s=0, pool=pool ) self.assertFalse(detok.detokenize(JELLO_WORLD_TOKEN).ok()) with open(file.name, 'wb') as fd: tokens.write_binary(db, fd) + # After the change but before the pool runs in another thread, + # the token should not exist. + self.assertFalse(detok.detokenize(JELLO_WORLD_TOKEN).ok()) + + # After the pool is allowed to process, it should. 
+ pool.process() self.assertTrue(detok.detokenize(JELLO_WORLD_TOKEN).ok()) finally: os.unlink(file.name) + def test_update_with_directory(self, mock_getmtime): + """Tests the update command with a directory format database.""" + db = database.load_token_database( + io.BytesIO(ELF_WITH_TOKENIZER_SECTIONS) + ) + self.assertEqual(len(db), TOKENS_IN_ELF) + + the_time = [100] + + def move_back_time_if_file_exists(path): + if os.path.exists(path): + the_time[0] -= 1 + return the_time[0] + + raise FileNotFoundError + + mock_getmtime.side_effect = move_back_time_if_file_exists + + with tempfile.TemporaryDirectory() as dbdir: + with tempfile.NamedTemporaryFile( + 'wb', delete=False, suffix='.pw_tokenizer.csv', dir=dbdir + ) as matching_suffix_file, tempfile.NamedTemporaryFile( + 'wb', delete=False, suffix='.not.right', dir=dbdir + ) as mismatched_suffix_file: + try: + matching_suffix_file.close() + mismatched_suffix_file.close() + + pool = ManualPoolExecutor() + detok = detokenize.AutoUpdatingDetokenizer( + dbdir, min_poll_period_s=0, pool=pool + ) + self.assertFalse(detok.detokenize(JELLO_WORLD_TOKEN).ok()) + + with open(mismatched_suffix_file.name, 'wb') as fd: + tokens.write_csv(db, fd) + pool.process() + self.assertFalse(detok.detokenize(JELLO_WORLD_TOKEN).ok()) + + with open(matching_suffix_file.name, 'wb') as fd: + tokens.write_csv(db, fd) + + # After the change but before the pool runs in another + # thread, the token should not exist. + self.assertFalse(detok.detokenize(JELLO_WORLD_TOKEN).ok()) + pool.process() + + # After the pool is allowed to process, it should. + self.assertTrue(detok.detokenize(JELLO_WORLD_TOKEN).ok()) + finally: + os.unlink(mismatched_suffix_file.name) + os.unlink(matching_suffix_file.name) + os.rmdir(dbdir) + # The database stays around if the file is deleted. 
self.assertTrue(detok.detokenize(JELLO_WORLD_TOKEN).ok()) @@ -507,7 +600,7 @@ class AutoUpdatingDetokenizerTest(unittest.TestCase): file.close() detok = detokenize.AutoUpdatingDetokenizer( - file.name, min_poll_period_s=0 + file.name, min_poll_period_s=0, pool=InlinePoolExecutor() ) self.assertTrue(detok.detokenize(JELLO_WORLD_TOKEN).ok()) @@ -527,7 +620,9 @@ class AutoUpdatingDetokenizerTest(unittest.TestCase): def test_token_domain_in_str(self, _) -> None: """Tests a str containing a domain""" detok = detokenize.AutoUpdatingDetokenizer( - f'{ELF_WITH_TOKENIZER_SECTIONS_PATH}#.*', min_poll_period_s=0 + f'{ELF_WITH_TOKENIZER_SECTIONS_PATH}#.*', + min_poll_period_s=0, + pool=InlinePoolExecutor(), ) self.assertEqual( len(detok.database), TOKENS_IN_ELF_WITH_TOKENIZER_SECTIONS @@ -536,7 +631,9 @@ class AutoUpdatingDetokenizerTest(unittest.TestCase): def test_token_domain_in_path(self, _) -> None: """Tests a Path() containing a domain""" detok = detokenize.AutoUpdatingDetokenizer( - Path(f'{ELF_WITH_TOKENIZER_SECTIONS_PATH}#.*'), min_poll_period_s=0 + Path(f'{ELF_WITH_TOKENIZER_SECTIONS_PATH}#.*'), + min_poll_period_s=0, + pool=InlinePoolExecutor(), ) self.assertEqual( len(detok.database), TOKENS_IN_ELF_WITH_TOKENIZER_SECTIONS @@ -545,14 +642,18 @@ class AutoUpdatingDetokenizerTest(unittest.TestCase): def test_token_no_domain_in_str(self, _) -> None: """Tests a str without a domain""" detok = detokenize.AutoUpdatingDetokenizer( - str(ELF_WITH_TOKENIZER_SECTIONS_PATH), min_poll_period_s=0 + str(ELF_WITH_TOKENIZER_SECTIONS_PATH), + min_poll_period_s=0, + pool=InlinePoolExecutor(), ) self.assertEqual(len(detok.database), TOKENS_IN_ELF) def test_token_no_domain_in_path(self, _) -> None: """Tests a Path() without a domain""" detok = detokenize.AutoUpdatingDetokenizer( - ELF_WITH_TOKENIZER_SECTIONS_PATH, min_poll_period_s=0 + ELF_WITH_TOKENIZER_SECTIONS_PATH, + min_poll_period_s=0, + pool=InlinePoolExecutor(), ) self.assertEqual(len(detok.database), TOKENS_IN_ELF) @@ -561,39 +662,173 @@ def _next_char(message: bytes) -> bytes: return bytes(b + 1 for b in message) -class PrefixedMessageDecoderTest(unittest.TestCase): - def setUp(self): - super().setUp() - self.decode = detokenize.PrefixedMessageDecoder('$', 'abcdefg') +class NestedMessageParserTest(unittest.TestCase): + """Tests parsing prefixed messages.""" + + class _Case(NamedTuple): + data: bytes + expected: bytes + title: str + transform: Callable[[bytes], bytes] = _next_char + + TRANSFORM_TEST_CASES = ( + _Case(b'$abcd', b'%bcde', 'single message'), + _Case( + b'$$WHAT?$abc$WHY? is this $ok $', + b'%%WHAT?%bcd%WHY? 
is this %ok %', + 'message and non-message', + ), + _Case(b'$1$', b'%1%', 'empty message'), + _Case(b'$abc$defgh', b'%bcd%efghh', 'sequential message'), + _Case( + b'w$abcx$defygh$$abz', + b'w$ABCx$DEFygh$$ABz', + 'interspersed start/end non-message', + bytes.upper, + ), + _Case( + b'$abcx$defygh$$ab', + b'$ABCx$DEFygh$$AB', + 'interspersed start/end message ', + bytes.upper, + ), + ) + + def setUp(self) -> None: + self.decoder = detokenize.NestedMessageParser('$', 'abcdefg') + + def test_transform_io(self) -> None: + for data, expected, title, transform in self.TRANSFORM_TEST_CASES: + self.assertEqual( + expected, + b''.join( + self.decoder.transform_io(io.BytesIO(data), transform) + ), + f'{title}: {data!r}', + ) + + def test_transform_bytes_with_flush(self) -> None: + for data, expected, title, transform in self.TRANSFORM_TEST_CASES: + self.assertEqual( + expected, + self.decoder.transform(data, transform, flush=True), + f'{title}: {data!r}', + ) + + def test_transform_bytes_sequential(self) -> None: + transform = lambda message: message.upper().replace(b'$', b'*') - def test_transform_single_message(self): + self.assertEqual(self.decoder.transform(b'abc$abcd', transform), b'abc') + self.assertEqual(self.decoder.transform(b'$', transform), b'*ABCD') + self.assertEqual(self.decoder.transform(b'$b', transform), b'*') + self.assertEqual(self.decoder.transform(b'', transform), b'') + self.assertEqual(self.decoder.transform(b' ', transform), b'*B ') + self.assertEqual(self.decoder.transform(b'hello', transform), b'hello') + self.assertEqual(self.decoder.transform(b'?? $ab', transform), b'?? ') self.assertEqual( - b'%bcde', - b''.join(self.decode.transform(io.BytesIO(b'$abcd'), _next_char)), + self.decoder.transform(b'123$ab4$56$a', transform), b'*AB123*AB4*56' ) + self.assertEqual( + self.decoder.transform(b'bc', transform, flush=True), b'*ABC' + ) + + MESSAGES_TEST: Any = ( + (b'123$abc456$a', (False, b'123'), (True, b'$abc'), (False, b'456')), + (b'7$abcd', (True, b'$a'), (False, b'7')), + (b'e',), + (b'',), + (b'$', (True, b'$abcde')), + (b'$', (True, b'$')), + (b'$a$b$c', (True, b'$'), (True, b'$a'), (True, b'$b')), + (b'1', (True, b'$c'), (False, b'1')), + (b'',), + (b'?', (False, b'?')), + (b'!@', (False, b'!@')), + (b'%^&', (False, b'%^&')), + ) - def test_transform_message_amidst_other_only_affects_message(self): + def test_read_messages(self) -> None: + for step in self.MESSAGES_TEST: + data: bytes = step[0] + pieces: Tuple[Tuple[bool, bytes], ...] = step[1:] + self.assertEqual(tuple(self.decoder.read_messages(data)), pieces) + + def test_read_messages_flush(self) -> None: self.assertEqual( - b'%%WHAT?%bcd%WHY? is this %ok %', - b''.join( - self.decode.transform( - io.BytesIO(b'$$WHAT?$abc$WHY? is this $ok $'), _next_char - ) - ), + list(self.decoder.read_messages(b'123$a')), [(False, b'123')] ) + self.assertEqual(list(self.decoder.read_messages(b'b')), []) + self.assertEqual( + list(self.decoder.read_messages(b'', flush=True)), [(True, b'$ab')] + ) + + def test_read_messages_io(self) -> None: + # Rework the read_messages test data for stream input. 
+ data = io.BytesIO(b''.join(step[0] for step in self.MESSAGES_TEST)) + expected_pieces = sum((step[1:] for step in self.MESSAGES_TEST), ()) + + result = self.decoder.read_messages_io(data) + for expected_is_message, expected_data in expected_pieces: + if expected_is_message: + is_message, piece = next(result) + self.assertTrue(is_message) + self.assertEqual(expected_data, piece) + else: # the IO version yields non-messages byte by byte + for byte in expected_data: + is_message, piece = next(result) + self.assertFalse(is_message) + self.assertEqual(bytes([byte]), piece) - def test_transform_empty_message(self): + +class DetokenizeNested(unittest.TestCase): + """Tests detokenizing nested tokens""" + + def test_nested_hashed_arg(self): + detok = detokenize.Detokenizer( + tokens.Database( + [ + tokens.TokenizedStringEntry(0xA, 'tokenized argument'), + tokens.TokenizedStringEntry( + 2, + 'This is a ' + '$#%08x', + ), + ] + ) + ) self.assertEqual( - b'%1%', - b''.join(self.decode.transform(io.BytesIO(b'$1$'), _next_char)), + str(detok.detokenize(b'\x02\0\0\0\x14')), + 'This is a tokenized argument', ) - def test_transform_sequential_messages(self): + def test_nested_base64_arg(self): + detok = detokenize.Detokenizer( + tokens.Database( + [ + tokens.TokenizedStringEntry(1, 'base64 argument'), + tokens.TokenizedStringEntry(2, 'This is a %s'), + ] + ) + ) self.assertEqual( - b'%bcd%efghh', - b''.join( - self.decode.transform(io.BytesIO(b'$abc$defgh'), _next_char) - ), + str(detok.detokenize(b'\x02\0\0\0\x09$AQAAAA==')), # token for 1 + 'This is a base64 argument', + ) + + def test_deeply_nested_arg(self): + detok = detokenize.Detokenizer( + tokens.Database( + [ + tokens.TokenizedStringEntry(1, '$10#0000000005'), + tokens.TokenizedStringEntry(2, 'This is a $#%08x'), + tokens.TokenizedStringEntry(3, 'deeply nested argument'), + tokens.TokenizedStringEntry(4, '$AQAAAA=='), + tokens.TokenizedStringEntry(5, '$AwAAAA=='), + ] + ) + ) + self.assertEqual( + str(detok.detokenize(b'\x02\0\0\0\x08')), # token for 4 + 'This is a deeply nested argument', ) @@ -627,6 +862,10 @@ class DetokenizeBase64(unittest.TestCase): (JELLO + b'$a' + JELLO + b'bcd', b'Jello, world!$aJello, world!bcd'), (b'$3141', b'$3141'), (JELLO + b'$3141', b'Jello, world!$3141'), + ( + JELLO + b'$a' + JELLO + b'b' + JELLO + b'c', + b'Jello, world!$aJello, world!bJello, world!c', + ), (RECURSION, b'The secret message is "Jello, world!"'), ( RECURSION_2, @@ -650,7 +889,7 @@ class DetokenizeBase64(unittest.TestCase): output = io.BytesIO() self.detok.detokenize_base64_live(io.BytesIO(data), output, '$') - self.assertEqual(expected, output.getvalue()) + self.assertEqual(expected, output.getvalue(), f'Input: {data!r}') def test_detokenize_base64_to_file(self): for data, expected in self.TEST_CASES: @@ -670,6 +909,52 @@ class DetokenizeBase64(unittest.TestCase): ) +class DetokenizeInfiniteRecursion(unittest.TestCase): + """Tests that infinite Base64 token recursion resolves.""" + + def setUp(self): + super().setUp() + self.detok = detokenize.Detokenizer( + tokens.Database( + [ + tokens.TokenizedStringEntry(0, '$AAAAAA=='), # token for 0 + tokens.TokenizedStringEntry(1, '$AgAAAA=='), # token for 2 + tokens.TokenizedStringEntry(2, '$#00000003'), # token for 3 + tokens.TokenizedStringEntry(3, '$AgAAAA=='), # token for 2 + ] + ) + ) + + def test_detokenize_self_recursion(self): + for depth in range(5): + self.assertEqual( + self.detok.detokenize_text( + b'This one is deep: $AAAAAA==', recursion=depth + ), + b'This one is deep: $AAAAAA==', + ) + + 
def test_detokenize_self_recursion_default(self): + self.assertEqual( + self.detok.detokenize_text( + b'This one is deep: $AAAAAA==', + ), + b'This one is deep: $AAAAAA==', + ) + + def test_detokenize_cyclic_recursion_even(self): + self.assertEqual( + self.detok.detokenize_text(b'I said "$AQAAAA=="', recursion=6), + b'I said "$AgAAAA=="', + ) + + def test_detokenize_cyclic_recursion_odd(self): + self.assertEqual( + self.detok.detokenize_text(b'I said "$AQAAAA=="', recursion=7), + b'I said "$#00000003"', + ) + + class DetokenizeBase64InfiniteRecursion(unittest.TestCase): """Tests that infinite Bas64 token recursion resolves.""" @@ -697,7 +982,7 @@ class DetokenizeBase64InfiniteRecursion(unittest.TestCase): def test_detokenize_self_recursion_default(self): self.assertEqual( - self.detok.detokenize_base64(b'This one is deep: $AAAAAA=='), + self.detok.detokenize_base64(b'This one is deep: $64#AAAAAA=='), b'This one is deep: $AAAAAA==', ) diff --git a/pw_tokenizer/py/generate_hash_test_data.py b/pw_tokenizer/py/generate_hash_test_data.py index b875f188f..b658b1fd2 100755 --- a/pw_tokenizer/py/generate_hash_test_data.py +++ b/pw_tokenizer/py/generate_hash_test_data.py @@ -23,7 +23,7 @@ from pw_tokenizer import tokens HASH_LENGTHS = 80, 96, 128 HASH_MACRO = 'PW_TOKENIZER_65599_FIXED_LENGTH_{}_HASH' -FILE_HEADER = """\ +SHARED_HEADER = """\ // Copyright {year} The Pigweed Authors // // Licensed under the Apache License, Version 2.0 (the "License"); you may not @@ -42,6 +42,9 @@ FILE_HEADER = """\ // // This file was generated by {script}. // To make changes, update the script and run it to generate new files. +""" + +CPP_HEADER = """\ #pragma once #include <cstddef> @@ -62,7 +65,7 @@ inline constexpr struct {{ """ -FILE_FOOTER = """ +CPP_FOOTER = """ }; // kHashTests // clang-format on @@ -70,14 +73,31 @@ FILE_FOOTER = """ } // namespace pw::tokenizer """ -_TEST_CASE = """{{ - std::string_view("{str}", {string_length}u), +_CPP_TEST_CASE = """{{ + std::string_view("{str}", {string_length}u), // NOLINT(bugprone-string-constructor) {hash_length}u, // fixed hash length UINT32_C({hash}), // Python-calculated hash {macro}("{str}"), // macro-calculated hash }}, """ +RUST_HEADER = """ +fn test_cases() -> Vec<TestCase> {{ + vec![ +""" + +RUST_FOOTER = """ + ] +} +""" + +_RUST_TEST_CASE = """ TestCase{{ + string: b"{str}", + hash_length: {hash_length}, + hash: {hash}, + }}, +""" + def _include_paths(lengths): return '\n'.join( @@ -89,7 +109,7 @@ def _include_paths(lengths): ) -def _test_case_at_length(data, hash_length): +def _test_case_at_length(test_case_template, data, hash_length): """Generates a test case for a particular hash length.""" if isinstance(data, str): @@ -100,7 +120,7 @@ def _test_case_at_length(data, hash_length): else: escaped_str = ''.join(r'\x{:02x}'.format(b) for b in data) - return _TEST_CASE.format( + return test_case_template.format( str=escaped_str, string_length=len(data), hash_length=hash_length, @@ -109,22 +129,23 @@ def _test_case_at_length(data, hash_length): ) -def test_case(data): +def test_case(test_case_template, data): return ''.join( - _test_case_at_length(data, length) for length in (80, 96, 128) + _test_case_at_length(test_case_template, data, length) + for length in (80, 96, 128) ) -def generate_test_cases(): - yield test_case('') - yield test_case(b'\xa1') - yield test_case(b'\xff') - yield test_case('\0') - yield test_case('\0\0') - yield test_case('a') - yield test_case('A') - yield test_case('hello, "world"') - yield test_case('YO' * 100) +def 
generate_test_cases(test_case_template): + yield test_case(test_case_template, '') + yield test_case(test_case_template, b'\xa1') + yield test_case(test_case_template, b'\xff') + yield test_case(test_case_template, '\0') + yield test_case(test_case_template, '\0\0') + yield test_case(test_case_template, 'a') + yield test_case(test_case_template, 'A') + yield test_case(test_case_template, 'hello, "world"') + yield test_case(test_case_template, 'YO' * 100) random.seed(600613) @@ -133,37 +154,60 @@ def generate_test_cases(): ) for i in range(1, 16): - yield test_case(random_string(i)) - yield test_case(random_string(i)) + yield test_case(test_case_template, random_string(i)) + yield test_case(test_case_template, random_string(i)) for length in HASH_LENGTHS: - yield test_case(random_string(length - 1)) - yield test_case(random_string(length)) - yield test_case(random_string(length + 1)) + yield test_case(test_case_template, random_string(length - 1)) + yield test_case(test_case_template, random_string(length)) + yield test_case(test_case_template, random_string(length + 1)) -if __name__ == '__main__': +def generate_file( + path_array, header_template, footer_template, test_case_template +): path = os.path.realpath( - os.path.join( - os.path.dirname(__file__), - '..', - 'pw_tokenizer_private', - 'generated_hash_test_cases.h', - ) + os.path.join(os.path.dirname(__file__), *path_array) ) with open(path, 'w') as output: output.write( - FILE_HEADER.format( + SHARED_HEADER.format( year=datetime.date.today().year, script=os.path.basename(__file__), + ) + ) + output.write( + header_template.format( includes=_include_paths(HASH_LENGTHS), ) ) - for case in generate_test_cases(): + for case in generate_test_cases(test_case_template): output.write(case) - output.write(FILE_FOOTER) + output.write(footer_template) + print('Wrote test data to', path) + - print('Wrote test data to', path) +if __name__ == '__main__': + generate_file( + [ + '..', + 'pw_tokenizer_private', + 'generated_hash_test_cases.h', + ], + CPP_HEADER, + CPP_FOOTER, + _CPP_TEST_CASE, + ) + generate_file( + [ + '..', + 'rust', + 'pw_tokenizer_core_test_cases.rs', + ], + RUST_HEADER, + RUST_FOOTER, + _RUST_TEST_CASE, + ) diff --git a/pw_tokenizer/py/pw_tokenizer/database.py b/pw_tokenizer/py/pw_tokenizer/database.py index 26a32a7fa..54d142eff 100755 --- a/pw_tokenizer/py/pw_tokenizer/database.py +++ b/pw_tokenizer/py/pw_tokenizer/database.py @@ -297,6 +297,9 @@ def _handle_create( f'The file {database} already exists! Use --force to overwrite.' ) + if not database.parent.exists(): + database.parent.mkdir(parents=True) + if output_type == 'directory': if str(database) == '-': raise ValueError( diff --git a/pw_tokenizer/py/pw_tokenizer/detokenize.py b/pw_tokenizer/py/pw_tokenizer/detokenize.py index 3aa7a3a8b..c777252da 100755 --- a/pw_tokenizer/py/pw_tokenizer/detokenize.py +++ b/pw_tokenizer/py/pw_tokenizer/detokenize.py @@ -20,11 +20,11 @@ or a file object for an ELF file or CSV. Then, call the detokenize method with encoded messages, one at a time. The detokenize method returns a DetokenizedString object with the result. -For example, +For example:: from pw_tokenizer import detokenize - detok = detokenize.Detokenizer('path/to/my/image.elf') + detok = detokenize.Detokenizer('path/to/firmware/image.elf') print(detok.detokenize(b'\x12\x34\x56\x78\x03hi!')) This module also provides a command line interface for decoding and detokenizing @@ -34,6 +34,8 @@ messages from a file or stdin. 
import argparse import base64 import binascii +from concurrent.futures import Executor, ThreadPoolExecutor +import enum import io import logging import os @@ -42,6 +44,7 @@ import re import string import struct import sys +import threading import time from typing import ( AnyStr, @@ -70,10 +73,48 @@ except ImportError: _LOG = logging.getLogger('pw_tokenizer') ENCODED_TOKEN = struct.Struct('<I') -BASE64_PREFIX = encode.BASE64_PREFIX.encode() +_BASE64_CHARS = string.ascii_letters + string.digits + '+/-_=' DEFAULT_RECURSION = 9 +NESTED_TOKEN_PREFIX = encode.NESTED_TOKEN_PREFIX.encode() +NESTED_TOKEN_BASE_PREFIX = encode.NESTED_TOKEN_BASE_PREFIX.encode() + +_BASE8_TOKEN_REGEX = rb'(?P<base8>[0-7]{11})' +_BASE10_TOKEN_REGEX = rb'(?P<base10>[0-9]{10})' +_BASE16_TOKEN_REGEX = rb'(?P<base16>[A-Fa-f0-9]{8})' +_BASE64_TOKEN_REGEX = ( + rb'(?P<base64>' + # Tokenized Base64 contains 0 or more blocks of four Base64 chars. + rb'(?:[A-Za-z0-9+/\-_]{4})*' + # The last block of 4 chars may have one or two padding chars (=). + rb'(?:[A-Za-z0-9+/\-_]{3}=|[A-Za-z0-9+/\-_]{2}==)?' + rb')' +) +_NESTED_TOKEN_FORMATS = ( + _BASE8_TOKEN_REGEX, + _BASE10_TOKEN_REGEX, + _BASE16_TOKEN_REGEX, + _BASE64_TOKEN_REGEX, +) + +_RawIo = Union[io.RawIOBase, BinaryIO] +_RawIoOrBytes = Union[_RawIo, bytes] + -_RawIO = Union[io.RawIOBase, BinaryIO] +def _token_regex(prefix: bytes) -> Pattern[bytes]: + """Returns a regular expression for prefixed tokenized strings.""" + return re.compile( + # Tokenized strings start with the prefix character ($). + re.escape(prefix) + # Optional; no base specifier defaults to BASE64. + # Hash (#) with no number specified defaults to Base-16. + + rb'(?P<basespec>(?P<base>[0-9]*)?' + + NESTED_TOKEN_BASE_PREFIX + + rb')?' + # Match one of the following token formats. + + rb'(' + + rb'|'.join(_NESTED_TOKEN_FORMATS) + + rb')' + ) class DetokenizedString: @@ -85,6 +126,7 @@ class DetokenizedString: format_string_entries: Iterable[tuple], encoded_message: bytes, show_errors: bool = False, + recursive_detokenize: Optional[Callable[[str], str]] = None, ): self.token = token self.encoded_message = encoded_message @@ -99,6 +141,12 @@ class DetokenizedString: result = fmt.format( encoded_message[ENCODED_TOKEN.size :], show_errors ) + if recursive_detokenize: + result = decode.FormattedString( + recursive_detokenize(result.value), + result.args, + result.remaining, + ) decode_attempts.append((result.score(entry.date_removed), result)) # Sort the attempts by the score so the most likely results are first. @@ -186,28 +234,39 @@ class Detokenizer: """ self.show_errors = show_errors + self._database_lock = threading.Lock() + # Cache FormatStrings for faster lookup & formatting. 
self._cache: Dict[int, List[_TokenizedFormatString]] = {} self._initialize_database(token_database_or_elf) def _initialize_database(self, token_sources: Iterable) -> None: - self.database = database.load_token_database(*token_sources) - self._cache.clear() + with self._database_lock: + self.database = database.load_token_database(*token_sources) + self._cache.clear() def lookup(self, token: int) -> List[_TokenizedFormatString]: """Returns (TokenizedStringEntry, FormatString) list for matches.""" - try: - return self._cache[token] - except KeyError: - format_strings = [ - _TokenizedFormatString(entry, decode.FormatString(str(entry))) - for entry in self.database.token_to_entries[token] - ] - self._cache[token] = format_strings - return format_strings - - def detokenize(self, encoded_message: bytes) -> DetokenizedString: + with self._database_lock: + try: + return self._cache[token] + except KeyError: + format_strings = [ + _TokenizedFormatString( + entry, decode.FormatString(str(entry)) + ) + for entry in self.database.token_to_entries[token] + ] + self._cache[token] = format_strings + return format_strings + + def detokenize( + self, + encoded_message: bytes, + prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX, + recursion: int = DEFAULT_RECURSION, + ) -> DetokenizedString: """Decodes and detokenizes a message as a DetokenizedString.""" if not encoded_message: return DetokenizedString( @@ -222,14 +281,25 @@ class Detokenizer: encoded_message += b'\0' * missing_token_bytes (token,) = ENCODED_TOKEN.unpack_from(encoded_message) + + recursive_detokenize = None + if recursion > 0: + recursive_detokenize = self._detokenize_nested_callback( + prefix, recursion + ) + return DetokenizedString( - token, self.lookup(token), encoded_message, self.show_errors + token, + self.lookup(token), + encoded_message, + self.show_errors, + recursive_detokenize, ) - def detokenize_base64( + def detokenize_text( self, data: AnyStr, - prefix: Union[str, bytes] = BASE64_PREFIX, + prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX, recursion: int = DEFAULT_RECURSION, ) -> AnyStr: """Decodes and replaces prefixed Base64 messages in the provided data. 
@@ -242,88 +312,174 @@ class Detokenizer: Returns: copy of the data with all recognized tokens decoded """ - output = io.BytesIO() - self.detokenize_base64_to_file(data, output, prefix, recursion) - result = output.getvalue() - return result.decode() if isinstance(data, str) else result + return self._detokenize_nested_callback(prefix, recursion)(data) - def detokenize_base64_to_file( + # TODO(gschen): remove unnecessary function + def detokenize_base64( + self, + data: AnyStr, + prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX, + recursion: int = DEFAULT_RECURSION, + ) -> AnyStr: + """Alias of detokenize_text for backwards compatibility.""" + return self.detokenize_text(data, prefix, recursion) + + def detokenize_text_to_file( self, - data: Union[str, bytes], + data: AnyStr, output: BinaryIO, - prefix: Union[str, bytes] = BASE64_PREFIX, + prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX, recursion: int = DEFAULT_RECURSION, ) -> None: """Decodes prefixed Base64 messages in data; decodes to output file.""" - data = data.encode() if isinstance(data, str) else data - prefix = prefix.encode() if isinstance(prefix, str) else prefix + output.write(self._detokenize_nested(data, prefix, recursion)) - output.write( - _base64_message_regex(prefix).sub( - self._detokenize_prefixed_base64(prefix, recursion), data - ) - ) + # TODO(gschen): remove unnecessary function + def detokenize_base64_to_file( + self, + data: AnyStr, + output: BinaryIO, + prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX, + recursion: int = DEFAULT_RECURSION, + ) -> None: + """Alias of detokenize_text_to_file for backwards compatibility.""" + self.detokenize_text_to_file(data, output, prefix, recursion) - def detokenize_base64_live( + def detokenize_text_live( self, - input_file: _RawIO, + input_file: _RawIo, output: BinaryIO, - prefix: Union[str, bytes] = BASE64_PREFIX, + prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX, recursion: int = DEFAULT_RECURSION, ) -> None: """Reads chars one-at-a-time, decoding messages; SLOW for big files.""" - prefix_bytes = prefix.encode() if isinstance(prefix, str) else prefix - - base64_message = _base64_message_regex(prefix_bytes) def transform(data: bytes) -> bytes: - return base64_message.sub( - self._detokenize_prefixed_base64(prefix_bytes, recursion), data - ) + return self._detokenize_nested(data.decode(), prefix, recursion) - for message in PrefixedMessageDecoder( - prefix, string.ascii_letters + string.digits + '+/-_=' - ).transform(input_file, transform): + for message in NestedMessageParser(prefix, _BASE64_CHARS).transform_io( + input_file, transform + ): output.write(message) # Flush each line to prevent delays when piping between processes. 
if b'\n' in message: output.flush() - def _detokenize_prefixed_base64( - self, prefix: bytes, recursion: int - ) -> Callable[[Match[bytes]], bytes]: - """Returns a function that decodes prefixed Base64.""" + # TODO(gschen): remove unnecessary function + def detokenize_base64_live( + self, + input_file: _RawIo, + output: BinaryIO, + prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX, + recursion: int = DEFAULT_RECURSION, + ) -> None: + """Alias of detokenize_text_live for backwards compatibility.""" + self.detokenize_text_live(input_file, output, prefix, recursion) - def decode_and_detokenize(match: Match[bytes]) -> bytes: - """Decodes prefixed base64 with this detokenizer.""" - original = match.group(0) + def _detokenize_nested_callback( + self, + prefix: Union[str, bytes], + recursion: int, + ) -> Callable[[AnyStr], AnyStr]: + """Returns a function that replaces all tokens for a given string.""" - try: - detokenized_string = self.detokenize( - base64.b64decode(original[1:], validate=True) - ) - if detokenized_string.matches(): - result = str(detokenized_string).encode() + def detokenize(message: AnyStr) -> AnyStr: + result = self._detokenize_nested(message, prefix, recursion) + return result.decode() if isinstance(message, str) else result + + return detokenize + + def _detokenize_nested( + self, + message: Union[str, bytes], + prefix: Union[str, bytes], + recursion: int, + ) -> bytes: + """Returns the message with recognized tokens replaced. + + Message data is internally handled as bytes regardless of input message + type and returns the result as bytes. + """ + # A unified format across the token types is required for regex + # consistency. + message = message.encode() if isinstance(message, str) else message + prefix = prefix.encode() if isinstance(prefix, str) else prefix + + if not self.database: + return message + + result = message + for _ in range(recursion - 1): + result = _token_regex(prefix).sub(self._detokenize_scan, result) + + if result == message: + return result + return result - if recursion > 0 and original != result: - result = self.detokenize_base64( - result, prefix, recursion - 1 - ) + def _detokenize_scan(self, match: Match[bytes]) -> bytes: + """Decodes prefixed tokens for one of multiple formats.""" + basespec = match.group('basespec') + base = match.group('base') - return result - except binascii.Error: - pass + if not basespec or (base == b'64'): + return self._detokenize_once_base64(match) + if not base: + base = b'16' + + return self._detokenize_once(match, base) + + def _detokenize_once( + self, + match: Match[bytes], + base: bytes, + ) -> bytes: + """Performs lookup on a plain token""" + original = match.group(0) + token = match.group('base' + base.decode()) + if not token: return original - return decode_and_detokenize + token = int(token, int(base)) + entries = self.database.token_to_entries[token] + + if len(entries) == 1: + return str(entries[0]).encode() + + # TODO(gschen): improve token collision reporting + + return original + + def _detokenize_once_base64( + self, + match: Match[bytes], + ) -> bytes: + """Performs lookup on a Base64 token""" + original = match.group(0) + + try: + encoded_token = match.group('base64') + if not encoded_token: + return original + + detokenized_string = self.detokenize( + base64.b64decode(encoded_token, validate=True), recursion=0 + ) + + if detokenized_string.matches(): + return str(detokenized_string).encode() + + except binascii.Error: + pass + + return original _PathOrStr = Union[Path, str] -# TODO(b/265334753): Reuse 
this function in database.py:LoadTokenDatabases +# TODO: b/265334753 - Reuse this function in database.py:LoadTokenDatabases def _parse_domain(path: _PathOrStr) -> Tuple[Path, Optional[Pattern[str]]]: """Extracts an optional domain regex pattern suffix from a path""" @@ -364,6 +520,12 @@ class AutoUpdatingDetokenizer(Detokenizer): return True def _last_modified_time(self) -> Optional[float]: + if self.path.is_dir(): + mtime = -1.0 + for child in self.path.glob(tokens.DIR_DB_GLOB): + mtime = max(mtime, os.path.getmtime(child)) + return mtime if mtime >= 0 else None + try: return os.path.getmtime(self.path) except FileNotFoundError: @@ -380,119 +542,181 @@ class AutoUpdatingDetokenizer(Detokenizer): return database.load_token_database() def __init__( - self, *paths_or_files: _PathOrStr, min_poll_period_s: float = 1.0 + self, + *paths_or_files: _PathOrStr, + min_poll_period_s: float = 1.0, + pool: Executor = ThreadPoolExecutor(max_workers=1), ) -> None: self.paths = tuple(self._DatabasePath(path) for path in paths_or_files) self.min_poll_period_s = min_poll_period_s self._last_checked_time: float = time.time() + # Thread pool to use for loading the databases. Limit to a single + # worker since this is low volume and not time critical. + self._pool = pool super().__init__(*(path.load() for path in self.paths)) + def __del__(self) -> None: + self._pool.shutdown(wait=False) + + def _reload_paths(self) -> None: + self._initialize_database([path.load() for path in self.paths]) + def _reload_if_changed(self) -> None: if time.time() - self._last_checked_time >= self.min_poll_period_s: self._last_checked_time = time.time() if any(path.updated() for path in self.paths): _LOG.info('Changes detected; reloading token database') - self._initialize_database(path.load() for path in self.paths) + self._pool.submit(self._reload_paths) def lookup(self, token: int) -> List[_TokenizedFormatString]: self._reload_if_changed() return super().lookup(token) -class PrefixedMessageDecoder: - """Parses messages that start with a prefix character from a byte stream.""" +class NestedMessageParser: + """Parses nested tokenized messages from a byte stream or string.""" - def __init__(self, prefix: Union[str, bytes], chars: Union[str, bytes]): - """Parses prefixed messages. + class _State(enum.Enum): + MESSAGE = 1 + NON_MESSAGE = 2 + + def __init__( + self, + prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX, + chars: Union[str, bytes] = _BASE64_CHARS, + ) -> None: + """Initializes a parser. Args: - prefix: one character that signifies the start of a message - chars: characters allowed in a message + prefix: one character that signifies the start of a message (``$``). + chars: characters allowed in a message """ - self._prefix = prefix.encode() if isinstance(prefix, str) else prefix + self._prefix = ord(prefix) if isinstance(chars, str): chars = chars.encode() - # Store the valid message bytes as a set of binary strings. - self._message_bytes = frozenset( - chars[i : i + 1] for i in range(len(chars)) - ) + # Store the valid message bytes as a set of byte values. + self._message_bytes = frozenset(chars) - if len(self._prefix) != 1 or self._prefix in self._message_bytes: + if len(prefix) != 1 or self._prefix in self._message_bytes: raise ValueError( - 'Invalid prefix {!r}: the prefix must be a single ' - 'character that is not a valid message character.'.format( - prefix - ) + f'Invalid prefix {prefix!r}: the prefix must be a single ' + 'character that is not a valid message character.' 
) - self.data = bytearray() + self._buffer = bytearray() + self._state: NestedMessageParser._State = self._State.NON_MESSAGE - def _read_next(self, fd: _RawIO) -> Tuple[bytes, int]: - """Returns the next character and its index.""" - char = fd.read(1) or b'' - index = len(self.data) - self.data += char - return char, index + def read_messages_io( + self, binary_io: _RawIo + ) -> Iterator[Tuple[bool, bytes]]: + """Reads prefixed messages from a byte stream (BinaryIO object). - def read_messages(self, binary_fd: _RawIO) -> Iterator[Tuple[bool, bytes]]: - """Parses prefixed messages; yields (is_message, contents) chunks.""" - message_start = None + Reads until EOF. If the stream is nonblocking (``read(1)`` returns + ``None``), then this function returns and may be called again with the + same IO object to continue parsing. Partial messages are preserved + between calls. - while True: - # This reads the file character-by-character. Non-message characters - # are yielded right away; message characters are grouped. - char, index = self._read_next(binary_fd) + Yields: + ``(is_message, contents)`` chunks. + """ + # The read may block indefinitely, depending on the IO object. + while (read_byte := binary_io.read(1)) != b'': + # Handle non-blocking IO by returning when no bytes are available. + if read_byte is None: + return - # If in a message, keep reading until the message completes. - if message_start is not None: - if char in self._message_bytes: - continue + for byte in read_byte: + yield from self._handle_byte(byte) - yield True, self.data[message_start:index] - message_start = None + if self._state is self._State.NON_MESSAGE: # yield non-message byte + yield from self._flush() - # Handle a non-message character. - if not char: - return + yield from self._flush() # Always flush after EOF + self._state = self._State.NON_MESSAGE - if char == self._prefix: - message_start = index - else: - yield False, char + def read_messages( + self, chunk: bytes, *, flush: bool = False + ) -> Iterator[Tuple[bool, bytes]]: + """Reads prefixed messages from a byte string. - def transform( - self, binary_fd: _RawIO, transform: Callable[[bytes], bytes] + This function may be called repeatedly with chunks of a stream. Partial + messages are preserved between calls, unless ``flush=True``. + + Args: + chunk: byte string that may contain nested messagses + flush: whether to flush any incomplete messages after processing + this chunk + + Yields: + ``(is_message, contents)`` chunks. 
+ """ + for byte in chunk: + yield from self._handle_byte(byte) + + if flush or self._state is self._State.NON_MESSAGE: + yield from self._flush() + + def _handle_byte(self, byte: int) -> Iterator[Tuple[bool, bytes]]: + if self._state is self._State.MESSAGE: + if byte not in self._message_bytes: + yield from self._flush() + if byte != self._prefix: + self._state = self._State.NON_MESSAGE + elif self._state is self._State.NON_MESSAGE: + if byte == self._prefix: + yield from self._flush() + self._state = self._State.MESSAGE + else: + raise NotImplementedError(f'Unsupported state: {self._state}') + + self._buffer.append(byte) + + def _flush(self) -> Iterator[Tuple[bool, bytes]]: + data = bytes(self._buffer) + self._buffer.clear() + if data: + yield self._state is self._State.MESSAGE, data + + def transform_io( + self, + binary_io: _RawIo, + transform: Callable[[bytes], bytes], ) -> Iterator[bytes]: """Yields the file with a transformation applied to the messages.""" - for is_message, chunk in self.read_messages(binary_fd): + for is_message, chunk in self.read_messages_io(binary_io): yield transform(chunk) if is_message else chunk - -def _base64_message_regex(prefix: bytes) -> Pattern[bytes]: - """Returns a regular expression for prefixed base64 tokenized strings.""" - return re.compile( - # Base64 tokenized strings start with the prefix character ($) - re.escape(prefix) - + ( - # Tokenized strings contain 0 or more blocks of four Base64 chars. - br'(?:[A-Za-z0-9+/\-_]{4})*' - # The last block of 4 chars may have one or two padding chars (=). - br'(?:[A-Za-z0-9+/\-_]{3}=|[A-Za-z0-9+/\-_]{2}==)?' + def transform( + self, + chunk: bytes, + transform: Callable[[bytes], bytes], + *, + flush: bool = False, + ) -> bytes: + """Yields the chunk with a transformation applied to the messages. + + Partial messages are preserved between calls unless ``flush=True``. + """ + return b''.join( + transform(data) if is_message else data + for is_message, data in self.read_messages(chunk, flush=flush) ) - ) # TODO(hepler): Remove this unnecessary function. def detokenize_base64( detokenizer: Detokenizer, data: bytes, - prefix: Union[str, bytes] = BASE64_PREFIX, + prefix: Union[str, bytes] = NESTED_TOKEN_PREFIX, recursion: int = DEFAULT_RECURSION, ) -> bytes: - """Alias for detokenizer.detokenize_base64 for backwards compatibility.""" + """Alias for detokenizer.detokenize_base64 for backwards compatibility. + + This function is deprecated; do not call it. + """ return detokenizer.detokenize_base64(data, prefix, recursion) @@ -596,10 +820,10 @@ def _parse_args() -> argparse.Namespace: subparser.add_argument( '-p', '--prefix', - default=BASE64_PREFIX, + default=NESTED_TOKEN_PREFIX, help=( 'The one-character prefix that signals the start of a ' - 'Base64-encoded message. (default: $)' + 'nested tokenized message. 
(default: $)' ), ) subparser.add_argument( diff --git a/pw_tokenizer/py/pw_tokenizer/encode.py b/pw_tokenizer/py/pw_tokenizer/encode.py index 5b8583256..f47e0ec27 100644 --- a/pw_tokenizer/py/pw_tokenizer/encode.py +++ b/pw_tokenizer/py/pw_tokenizer/encode.py @@ -23,7 +23,8 @@ from pw_tokenizer import tokens _INT32_MAX = 2**31 - 1 _UINT32_MAX = 2**32 - 1 -BASE64_PREFIX = '$' +NESTED_TOKEN_PREFIX = '$' +NESTED_TOKEN_BASE_PREFIX = '#' def _zig_zag_encode(value: int) -> int: diff --git a/pw_tokenizer/py/pw_tokenizer/proto/__init__.py b/pw_tokenizer/py/pw_tokenizer/proto/__init__.py index da11d5e5e..7e54835d2 100644 --- a/pw_tokenizer/py/pw_tokenizer/proto/__init__.py +++ b/pw_tokenizer/py/pw_tokenizer/proto/__init__.py @@ -36,7 +36,7 @@ def _tokenized_fields(proto: Message) -> Iterator[FieldDescriptor]: def decode_optionally_tokenized( detokenizer: detokenize.Detokenizer, data: bytes, - prefix: str = encode.BASE64_PREFIX, + prefix: str = encode.NESTED_TOKEN_PREFIX, ) -> str: """Decodes data that may be plain text or binary / Base64 tokenized text.""" # Try detokenizing as binary. @@ -70,7 +70,7 @@ def decode_optionally_tokenized( def detokenize_fields( detokenizer: detokenize.Detokenizer, proto: Message, - prefix: str = encode.BASE64_PREFIX, + prefix: str = encode.NESTED_TOKEN_PREFIX, ) -> None: """Detokenizes fields annotated as tokenized in the given proto. diff --git a/pw_tokenizer/py/pw_tokenizer/serial_detokenizer.py b/pw_tokenizer/py/pw_tokenizer/serial_detokenizer.py index ab673a55e..79dda3a5e 100644 --- a/pw_tokenizer/py/pw_tokenizer/serial_detokenizer.py +++ b/pw_tokenizer/py/pw_tokenizer/serial_detokenizer.py @@ -59,7 +59,7 @@ def _parse_args(): parser.add_argument( '-p', '--prefix', - default=detokenize.BASE64_PREFIX, + default=detokenize.NESTED_TOKEN_PREFIX, help=( 'The one-character prefix that signals the start of a ' 'Base64-encoded message. (default: $)' diff --git a/pw_tokenizer/py/pw_tokenizer/tokens.py b/pw_tokenizer/py/pw_tokenizer/tokens.py index b7ebac87c..fa9339cdf 100644 --- a/pw_tokenizer/py/pw_tokenizer/tokens.py +++ b/pw_tokenizer/py/pw_tokenizer/tokens.py @@ -575,7 +575,7 @@ class _BinaryDatabase(DatabaseFile): def add_and_discard_temporary( self, entries: Iterable[TokenizedStringEntry], commit: str ) -> None: - # TODO(b/241471465): Implement adding new tokens and removing + # TODO: b/241471465 - Implement adding new tokens and removing # temporary entries for binary databases. raise NotImplementedError( '--discard-temporary is currently only ' @@ -597,7 +597,7 @@ class _CSVDatabase(DatabaseFile): def add_and_discard_temporary( self, entries: Iterable[TokenizedStringEntry], commit: str ) -> None: - # TODO(b/241471465): Implement adding new tokens and removing + # TODO: b/241471465 - Implement adding new tokens and removing # temporary entries for CSV databases. raise NotImplementedError( '--discard-temporary is currently only ' @@ -607,12 +607,12 @@ class _CSVDatabase(DatabaseFile): # The suffix used for CSV files in a directory database. DIR_DB_SUFFIX = '.pw_tokenizer.csv' -_DIR_DB_GLOB = '*' + DIR_DB_SUFFIX +DIR_DB_GLOB = '*' + DIR_DB_SUFFIX def _parse_directory(directory: Path) -> Iterable[TokenizedStringEntry]: """Parses TokenizedStringEntries tokenizer CSV files in the directory.""" - for path in directory.glob(_DIR_DB_GLOB): + for path in directory.glob(DIR_DB_GLOB): yield from _CSVDatabase(path).entries() @@ -633,7 +633,7 @@ class _DirectoryDatabase(DatabaseFile): write_csv(self, fd) # Delete all CSV files except for the new CSV with everything. 
- for csv_file in self.path.glob(_DIR_DB_GLOB): + for csv_file in self.path.glob(DIR_DB_GLOB): if csv_file != new_file: csv_file.unlink() else: @@ -648,7 +648,7 @@ class _DirectoryDatabase(DatabaseFile): """Returns a list of files from a Git command, filtered to matc.""" try: output = subprocess.run( - ['git', *commands, _DIR_DB_GLOB], + ['git', *commands, DIR_DB_GLOB], capture_output=True, check=True, cwd=self.path, diff --git a/pw_tokenizer/py/setup.py b/pw_tokenizer/py/setup.py deleted file mode 100644 index fd5c1e6f7..000000000 --- a/pw_tokenizer/py/setup.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2021 The Pigweed Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not -# use this file except in compliance with the License. You may obtain a copy of -# the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations under -# the License. -"""The pw_tokenizer package. - -Installing pw_tokenizer with this setup.py does not include the -pw_tokenizer.proto package, since it contains a generated protobuf module. To -access pw_tokenizer.proto, install pw_tokenizer from GN.""" - -import setuptools # type: ignore - -setuptools.setup() # Package definition in setup.cfg |
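For reference, a minimal sketch of how the nested-token API added by this change might be used. It mirrors the new detokenize_test.py cases above; the token database entries and input text are illustrative only, not real Pigweed databases:

    from pw_tokenizer import detokenize, tokens

    detok = detokenize.Detokenizer(
        tokens.Database(
            [
                tokens.TokenizedStringEntry(1, 'base64 argument'),
                tokens.TokenizedStringEntry(2, 'This is a %s'),
            ]
        )
    )

    # Binary detokenization now resolves nested Base64 tokens in arguments.
    print(detok.detokenize(b'\x02\0\0\0\x09$AQAAAA=='))  # This is a base64 argument

    # detokenize_text() replaces prefixed tokens anywhere in a string;
    # detokenize_base64() remains as a backwards-compatible alias.
    print(detok.detokenize_text(b'status: $AQAAAA=='))  # b'status: base64 argument'

    # NestedMessageParser (which replaces PrefixedMessageDecoder) splits a
    # byte stream or chunk into message and non-message pieces.
    parser = detokenize.NestedMessageParser()
    for is_message, chunk in parser.read_messages(b'log: $AQAAAA== end', flush=True):
        print(is_message, chunk)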