aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.pylintrc1
-rw-r--r--pw_cli/py/pw_cli/envparse.py150
-rw-r--r--pw_cli/py/pw_cli/envparse_test.py143
3 files changed, 294 insertions, 0 deletions
diff --git a/.pylintrc b/.pylintrc
index 3445b61ff..816add7ed 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -299,6 +299,7 @@ good-names=i,
ex,
fd,
Run,
+ T,
_
# Include a hint for the correct naming format with invalid-name.
diff --git a/pw_cli/py/pw_cli/envparse.py b/pw_cli/py/pw_cli/envparse.py
new file mode 100644
index 000000000..a5c101c99
--- /dev/null
+++ b/pw_cli/py/pw_cli/envparse.py
@@ -0,0 +1,150 @@
+# Copyright 2020 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""The envparse module defines an environment variable parser."""
+
+import argparse
+import os
+from typing import Callable, Dict, Generic, IO, Literal, Mapping, NamedTuple
+from typing import Optional, TypeVar
+
+
+class EnvNamespace(argparse.Namespace): # pylint: disable=too-few-public-methods
+ """Base class for parsed environment variable namespaces."""
+
+
+T = TypeVar('T')
+TypeConversion = Callable[[str], T]
+
+
+class VariableDescriptor(NamedTuple, Generic[T]):
+ name: str
+ type: TypeConversion[T]
+ default: Optional[T]
+
+
+class EnvironmentValueError(Exception):
+ """Exception indicating a bad type conversion on an environment variable.
+
+ Stores a reference to the lower-level exception from the type conversion
+ function through the __cause__ attribute for more detailed information on
+ the error.
+ """
+ def __init__(self, variable: str, value: str):
+ self.variable: str = variable
+ self.value: str = value
+ super().__init__(
+ f'Bad value for environment variable {variable}: {value}')
+
+
+class EnvironmentParser:
+ """Parser for environment variables.
+
+ Args:
+ prefix: If provided, checks that all registered environment variables
+ start with the specified string.
+ error_on_unrecognized: If True and prefix is provided, will raise an
+ exception if the environment contains a variable with the specified
+ prefix that is not registered on the EnvironmentParser.
+
+ Example:
+
+ parser = envparse.EnvironmentParser(prefix='PW_')
+ parser.add_var('PW_LOG_LEVEL')
+ parser.add_var('PW_LOG_FILE', type=envparse.FileType('w'))
+ parser.add_var('PW_USE_COLOR', type=envparse.bool_type, default=False)
+ env = parser.parse_env()
+
+ configure_logging(env.PW_LOG_LEVEL, env.PW_LOG_FILE)
+ """
+ def __init__(self,
+ prefix: Optional[str] = None,
+ error_on_unrecognized: bool = True) -> None:
+ self._prefix: Optional[str] = prefix
+ self._error_on_unrecognized: bool = error_on_unrecognized
+ self._variables: Dict[str, VariableDescriptor] = {}
+
+ def add_var(
+ self,
+ name: str,
+ type: TypeConversion[T] = str, # pylint: disable=redefined-builtin
+ default: Optional[T] = None,
+ ) -> None:
+ """Registers an environment variable.
+
+ Args:
+ name: The environment variable's name.
+ type: Type conversion for the variable's value.
+ default: Default value for the variable.
+
+ Raises:
+ ValueError: If prefix was provided to the constructor and name does
+ not start with the prefix.
+ """
+ if self._prefix is not None and not name.startswith(self._prefix):
+ raise ValueError(
+ f'Variable {name} does not have prefix {self._prefix}')
+
+ self._variables[name] = VariableDescriptor(name, type, default)
+
+ def parse_env(self,
+ env: Optional[Mapping[str, str]] = None) -> EnvNamespace:
+ """Parses known environment variables into a namespace.
+
+ Args:
+ env: Dictionary of environment variables. Defaults to os.environ.
+
+ Raises:
+ EnvironmentValueError: If the type conversion fails.
+ """
+ if env is None:
+ env = os.environ
+
+ namespace = EnvNamespace()
+
+ for var, desc in self._variables.items():
+ if var not in env:
+ val = desc.default
+ else:
+ try:
+ val = desc.type(env[var])
+ except Exception as err:
+ raise EnvironmentValueError(var, env[var]) from err
+
+ setattr(namespace, var, val)
+
+ if self._prefix is not None and self._error_on_unrecognized:
+ for var in env:
+ if var.startswith(self._prefix) and var not in self._variables:
+ raise ValueError(
+ f'Unrecognized environment variable {var}')
+
+ return namespace
+
+ def __repr__(self) -> str:
+ return f'{type(self).__name__}(prefix={self._prefix})'
+
+
+def bool_type(value: str) -> bool:
+ return value == '1' or value.lower() == 'true'
+
+
+OpenMode = Literal['r', 'rb', 'w', 'wb']
+
+
+class FileType:
+ def __init__(self, mode: OpenMode) -> None:
+ self._mode: OpenMode = mode
+
+ def __call__(self, value: str) -> IO:
+ return open(value, self._mode)
diff --git a/pw_cli/py/pw_cli/envparse_test.py b/pw_cli/py/pw_cli/envparse_test.py
new file mode 100644
index 000000000..8962e5e51
--- /dev/null
+++ b/pw_cli/py/pw_cli/envparse_test.py
@@ -0,0 +1,143 @@
+# Copyright 2020 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Tests for pw_cli.envparse."""
+
+import math
+import unittest
+
+import pw_cli.envparse as envparse
+
+# pylint: disable=no-member
+
+
+class ErrorError(Exception):
+ pass
+
+
+def error(value: str):
+ raise ErrorError('error!')
+
+
+class TestEnvironmentParser(unittest.TestCase):
+ """Tests for envparse.EnvironmentParser."""
+ def setUp(self):
+ self.raw_env = {
+ 'PATH': '/bin:/usr/bin:/usr/local/bin',
+ 'FOO': '2020',
+ 'ReVeRsE': 'pigweed',
+ }
+
+ self.parser = envparse.EnvironmentParser()
+ self.parser.add_var('PATH')
+ self.parser.add_var('FOO', type=int)
+ self.parser.add_var('BAR', type=bool)
+ self.parser.add_var('BAZ', type=float, default=math.pi)
+ self.parser.add_var('ReVeRsE', type=lambda s: s[::-1])
+ self.parser.add_var('INT', type=int)
+ self.parser.add_var('ERROR', type=error)
+
+ def test_string_value(self):
+ env = self.parser.parse_env(env=self.raw_env)
+ self.assertEqual(env.PATH, self.raw_env['PATH'])
+
+ def test_int_value(self):
+ env = self.parser.parse_env(env=self.raw_env)
+ self.assertEqual(env.FOO, 2020)
+
+ def test_custom_value(self):
+ env = self.parser.parse_env(env=self.raw_env)
+ self.assertEqual(env.ReVeRsE, 'deewgip')
+
+ def test_empty_value(self):
+ env = self.parser.parse_env(env=self.raw_env)
+ self.assertEqual(env.BAR, None)
+
+ def test_default_value(self):
+ env = self.parser.parse_env(env=self.raw_env)
+ self.assertEqual(env.BAZ, math.pi)
+
+ def test_unknown_key(self):
+ env = self.parser.parse_env(env=self.raw_env)
+ with self.assertRaises(AttributeError):
+ env.BBBBB # pylint: disable=pointless-statement
+
+ def test_bad_value(self):
+ raw_env = {**self.raw_env, 'INT': 'not an int'}
+ with self.assertRaises(envparse.EnvironmentValueError) as ctx:
+ self.parser.parse_env(env=raw_env)
+
+ self.assertEqual(ctx.exception.variable, 'INT')
+ self.assertIsInstance(ctx.exception.__cause__, ValueError)
+
+ def test_custom_exception(self):
+ raw_env = {**self.raw_env, 'ERROR': 'error'}
+ with self.assertRaises(envparse.EnvironmentValueError) as ctx:
+ self.parser.parse_env(env=raw_env)
+
+ self.assertEqual(ctx.exception.variable, 'ERROR')
+ self.assertIsInstance(ctx.exception.__cause__, ErrorError)
+
+
+class TestEnvironmentParserWithPrefix(unittest.TestCase):
+ """Tests for envparse.EnvironmentParser using a prefix."""
+ def setUp(self):
+ self.raw_env = {
+ 'PW_FOO': '001',
+ 'PW_BAR': '010',
+ 'PW_BAZ': '100',
+ }
+
+ def test_parse_unrecognized_variable(self):
+ parser = envparse.EnvironmentParser(prefix='PW_')
+ parser.add_var('PW_FOO')
+ parser.add_var('PW_BAR')
+
+ with self.assertRaises(ValueError):
+ parser.parse_env(env=self.raw_env)
+
+ def test_parse_ignore_unrecognized(self):
+ parser = envparse.EnvironmentParser(prefix='PW_',
+ error_on_unrecognized=False)
+ parser.add_var('PW_FOO')
+ parser.add_var('PW_BAR')
+
+ env = parser.parse_env(env=self.raw_env)
+ self.assertEqual(env.PW_FOO, self.raw_env['PW_FOO'])
+ self.assertEqual(env.PW_BAR, self.raw_env['PW_BAR'])
+
+ def test_add_var_without_prefix(self):
+ parser = envparse.EnvironmentParser(prefix='PW_')
+ with self.assertRaises(ValueError):
+ parser.add_var('FOO')
+
+
+class TestBoolType(unittest.TestCase):
+ """Tests for envparse.bool_type."""
+ def setUp(self):
+ self.good_bools = ['true', '1', 'TRUE', 'tRuE']
+ self.bad_bools = [
+ '', 'false', '0', 'foo', '2', '999', 'ok', 'yes', 'no'
+ ]
+
+ def test_good_bools(self):
+ self.assertTrue(all(
+ envparse.bool_type(val) for val in self.good_bools))
+
+ def test_bad_bools(self):
+ self.assertFalse(any(
+ envparse.bool_type(val) for val in self.bad_bools))
+
+
+if __name__ == '__main__':
+ unittest.main()