diff options
-rw-r--r-- | .pylintrc | 1 | ||||
-rw-r--r-- | pw_cli/py/pw_cli/envparse.py | 150 | ||||
-rw-r--r-- | pw_cli/py/pw_cli/envparse_test.py | 143 |
3 files changed, 294 insertions, 0 deletions
@@ -299,6 +299,7 @@ good-names=i, ex, fd, Run, + T, _ # Include a hint for the correct naming format with invalid-name. diff --git a/pw_cli/py/pw_cli/envparse.py b/pw_cli/py/pw_cli/envparse.py new file mode 100644 index 000000000..a5c101c99 --- /dev/null +++ b/pw_cli/py/pw_cli/envparse.py @@ -0,0 +1,150 @@ +# Copyright 2020 The Pigweed Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +"""The envparse module defines an environment variable parser.""" + +import argparse +import os +from typing import Callable, Dict, Generic, IO, Literal, Mapping, NamedTuple +from typing import Optional, TypeVar + + +class EnvNamespace(argparse.Namespace): # pylint: disable=too-few-public-methods + """Base class for parsed environment variable namespaces.""" + + +T = TypeVar('T') +TypeConversion = Callable[[str], T] + + +class VariableDescriptor(NamedTuple, Generic[T]): + name: str + type: TypeConversion[T] + default: Optional[T] + + +class EnvironmentValueError(Exception): + """Exception indicating a bad type conversion on an environment variable. + + Stores a reference to the lower-level exception from the type conversion + function through the __cause__ attribute for more detailed information on + the error. + """ + def __init__(self, variable: str, value: str): + self.variable: str = variable + self.value: str = value + super().__init__( + f'Bad value for environment variable {variable}: {value}') + + +class EnvironmentParser: + """Parser for environment variables. + + Args: + prefix: If provided, checks that all registered environment variables + start with the specified string. + error_on_unrecognized: If True and prefix is provided, will raise an + exception if the environment contains a variable with the specified + prefix that is not registered on the EnvironmentParser. + + Example: + + parser = envparse.EnvironmentParser(prefix='PW_') + parser.add_var('PW_LOG_LEVEL') + parser.add_var('PW_LOG_FILE', type=envparse.FileType('w')) + parser.add_var('PW_USE_COLOR', type=envparse.bool_type, default=False) + env = parser.parse_env() + + configure_logging(env.PW_LOG_LEVEL, env.PW_LOG_FILE) + """ + def __init__(self, + prefix: Optional[str] = None, + error_on_unrecognized: bool = True) -> None: + self._prefix: Optional[str] = prefix + self._error_on_unrecognized: bool = error_on_unrecognized + self._variables: Dict[str, VariableDescriptor] = {} + + def add_var( + self, + name: str, + type: TypeConversion[T] = str, # pylint: disable=redefined-builtin + default: Optional[T] = None, + ) -> None: + """Registers an environment variable. + + Args: + name: The environment variable's name. + type: Type conversion for the variable's value. + default: Default value for the variable. + + Raises: + ValueError: If prefix was provided to the constructor and name does + not start with the prefix. + """ + if self._prefix is not None and not name.startswith(self._prefix): + raise ValueError( + f'Variable {name} does not have prefix {self._prefix}') + + self._variables[name] = VariableDescriptor(name, type, default) + + def parse_env(self, + env: Optional[Mapping[str, str]] = None) -> EnvNamespace: + """Parses known environment variables into a namespace. + + Args: + env: Dictionary of environment variables. Defaults to os.environ. + + Raises: + EnvironmentValueError: If the type conversion fails. + """ + if env is None: + env = os.environ + + namespace = EnvNamespace() + + for var, desc in self._variables.items(): + if var not in env: + val = desc.default + else: + try: + val = desc.type(env[var]) + except Exception as err: + raise EnvironmentValueError(var, env[var]) from err + + setattr(namespace, var, val) + + if self._prefix is not None and self._error_on_unrecognized: + for var in env: + if var.startswith(self._prefix) and var not in self._variables: + raise ValueError( + f'Unrecognized environment variable {var}') + + return namespace + + def __repr__(self) -> str: + return f'{type(self).__name__}(prefix={self._prefix})' + + +def bool_type(value: str) -> bool: + return value == '1' or value.lower() == 'true' + + +OpenMode = Literal['r', 'rb', 'w', 'wb'] + + +class FileType: + def __init__(self, mode: OpenMode) -> None: + self._mode: OpenMode = mode + + def __call__(self, value: str) -> IO: + return open(value, self._mode) diff --git a/pw_cli/py/pw_cli/envparse_test.py b/pw_cli/py/pw_cli/envparse_test.py new file mode 100644 index 000000000..8962e5e51 --- /dev/null +++ b/pw_cli/py/pw_cli/envparse_test.py @@ -0,0 +1,143 @@ +# Copyright 2020 The Pigweed Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +"""Tests for pw_cli.envparse.""" + +import math +import unittest + +import pw_cli.envparse as envparse + +# pylint: disable=no-member + + +class ErrorError(Exception): + pass + + +def error(value: str): + raise ErrorError('error!') + + +class TestEnvironmentParser(unittest.TestCase): + """Tests for envparse.EnvironmentParser.""" + def setUp(self): + self.raw_env = { + 'PATH': '/bin:/usr/bin:/usr/local/bin', + 'FOO': '2020', + 'ReVeRsE': 'pigweed', + } + + self.parser = envparse.EnvironmentParser() + self.parser.add_var('PATH') + self.parser.add_var('FOO', type=int) + self.parser.add_var('BAR', type=bool) + self.parser.add_var('BAZ', type=float, default=math.pi) + self.parser.add_var('ReVeRsE', type=lambda s: s[::-1]) + self.parser.add_var('INT', type=int) + self.parser.add_var('ERROR', type=error) + + def test_string_value(self): + env = self.parser.parse_env(env=self.raw_env) + self.assertEqual(env.PATH, self.raw_env['PATH']) + + def test_int_value(self): + env = self.parser.parse_env(env=self.raw_env) + self.assertEqual(env.FOO, 2020) + + def test_custom_value(self): + env = self.parser.parse_env(env=self.raw_env) + self.assertEqual(env.ReVeRsE, 'deewgip') + + def test_empty_value(self): + env = self.parser.parse_env(env=self.raw_env) + self.assertEqual(env.BAR, None) + + def test_default_value(self): + env = self.parser.parse_env(env=self.raw_env) + self.assertEqual(env.BAZ, math.pi) + + def test_unknown_key(self): + env = self.parser.parse_env(env=self.raw_env) + with self.assertRaises(AttributeError): + env.BBBBB # pylint: disable=pointless-statement + + def test_bad_value(self): + raw_env = {**self.raw_env, 'INT': 'not an int'} + with self.assertRaises(envparse.EnvironmentValueError) as ctx: + self.parser.parse_env(env=raw_env) + + self.assertEqual(ctx.exception.variable, 'INT') + self.assertIsInstance(ctx.exception.__cause__, ValueError) + + def test_custom_exception(self): + raw_env = {**self.raw_env, 'ERROR': 'error'} + with self.assertRaises(envparse.EnvironmentValueError) as ctx: + self.parser.parse_env(env=raw_env) + + self.assertEqual(ctx.exception.variable, 'ERROR') + self.assertIsInstance(ctx.exception.__cause__, ErrorError) + + +class TestEnvironmentParserWithPrefix(unittest.TestCase): + """Tests for envparse.EnvironmentParser using a prefix.""" + def setUp(self): + self.raw_env = { + 'PW_FOO': '001', + 'PW_BAR': '010', + 'PW_BAZ': '100', + } + + def test_parse_unrecognized_variable(self): + parser = envparse.EnvironmentParser(prefix='PW_') + parser.add_var('PW_FOO') + parser.add_var('PW_BAR') + + with self.assertRaises(ValueError): + parser.parse_env(env=self.raw_env) + + def test_parse_ignore_unrecognized(self): + parser = envparse.EnvironmentParser(prefix='PW_', + error_on_unrecognized=False) + parser.add_var('PW_FOO') + parser.add_var('PW_BAR') + + env = parser.parse_env(env=self.raw_env) + self.assertEqual(env.PW_FOO, self.raw_env['PW_FOO']) + self.assertEqual(env.PW_BAR, self.raw_env['PW_BAR']) + + def test_add_var_without_prefix(self): + parser = envparse.EnvironmentParser(prefix='PW_') + with self.assertRaises(ValueError): + parser.add_var('FOO') + + +class TestBoolType(unittest.TestCase): + """Tests for envparse.bool_type.""" + def setUp(self): + self.good_bools = ['true', '1', 'TRUE', 'tRuE'] + self.bad_bools = [ + '', 'false', '0', 'foo', '2', '999', 'ok', 'yes', 'no' + ] + + def test_good_bools(self): + self.assertTrue(all( + envparse.bool_type(val) for val in self.good_bools)) + + def test_bad_bools(self): + self.assertFalse(any( + envparse.bool_type(val) for val in self.bad_bools)) + + +if __name__ == '__main__': + unittest.main() |