aboutsummaryrefslogtreecommitdiff
path: root/pw_protobuf_compiler/py/pw_protobuf_compiler/python_protos.py
blob: 415611a2c5d9606a21e38f2711a0f450bdb2a112 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tools for compiling and importing Python protos on the fly."""

from collections.abc import Mapping
import hashlib
import importlib.util
import logging
import os
from pathlib import Path
import subprocess
import shlex
import tempfile
from types import ModuleType
from typing import (
    Dict,
    Generic,
    Iterable,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
)

try:
    # pylint: disable=wrong-import-position
    import black

    black_mode: Optional[black.Mode] = black.Mode(string_normalization=False)

    # pylint: enable=wrong-import-position
except ImportError:
    black = None  # type: ignore
    black_mode = None

# Module-level logger for this file's diagnostics.
_LOG = logging.getLogger(__name__)

# A filesystem path given either as a pathlib.Path or as a plain string.
PathOrStr = Union[Path, str]


def _find_protoc() -> str:
    """Locates a protoc binary to use for compiling protos."""
    if 'PROTOC' in os.environ:
        return os.environ['PROTOC']

    # Fallback is assuming `protoc` is on the system PATH.
    return 'protoc'


def compile_protos(
    output_dir: PathOrStr,
    proto_files: Iterable[PathOrStr],
    includes: Iterable[PathOrStr] = (),
) -> None:
    """Compiles proto files for Python by invoking the protobuf compiler.

    Proto files not covered by one of the provided include paths will have their
    directory added as an include path.

    Include paths are passed to protoc in a deterministic order: the explicit
    ``includes`` first (in the order given, duplicates removed), followed by
    any directories added for uncovered proto files, in proto-file order.
    protoc searches include paths in order, so a stable order keeps import
    resolution and the command line reproducible.

    Args:
      output_dir: directory protoc writes the generated Python modules to
      proto_files: paths to the .proto files to compile
      includes: include (-I) directories used to resolve proto imports

    Raises:
      subprocess.CalledProcessError: if protoc exits with a nonzero status
    """
    proto_paths: List[Path] = [Path(f).resolve() for f in proto_files]

    # A dict (rather than a set) deduplicates while preserving insertion
    # order; set iteration order depends on randomized str hashing.
    include_paths: Dict[Path, None] = {
        Path(d).resolve(): None for d in includes
    }

    for path in proto_paths:
        if not any(include in path.parents for include in include_paths):
            include_paths[path.parent] = None

    cmd: Tuple[PathOrStr, ...] = (
        _find_protoc(),
        '--experimental_allow_proto3_optional',
        '--python_out',
        os.path.abspath(output_dir),
        *(f'-I{d}' for d in include_paths),
        *proto_paths,
    )
    cmd_str = ' '.join(shlex.quote(str(c)) for c in cmd)

    _LOG.debug('%s', cmd_str)
    process = subprocess.run(cmd, capture_output=True)

    if process.returncode:
        _LOG.error(
            'protoc invocation failed!\n%s\n%s',
            cmd_str,
            # Never let undecodable compiler output mask the real failure.
            process.stderr.decode(errors='replace'),
        )
        process.check_returncode()


def _import_module(name: str, path: str) -> ModuleType:
    spec = importlib.util.spec_from_file_location(name, path)
    assert spec is not None
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # type: ignore[union-attr]
    return module


def import_modules(directory: PathOrStr) -> Iterator:
    """Imports every .py module under a directory tree and yields them.

    Module names are dotted paths relative to the parent of ``directory``, so
    ``<directory>/sub/foo_pb2.py`` is imported as
    ``<basename of directory>.sub.foo_pb2``.

    Args:
      directory: root directory searched recursively for .py files

    Yields:
      each imported module; subdirectories and files are visited in sorted
      order so the sequence is deterministic
    """
    # Normalize first: with a trailing separator (or a bare '.'), dirname()
    # returns the directory itself (or ''), which yields module names such as
    # '..foo_pb2'.
    root = os.path.abspath(directory)
    parent = os.path.dirname(root)

    for dirpath, dirnames, files in os.walk(root):
        # Sorting dirnames in place controls the order os.walk descends.
        dirnames.sort()
        path_parts = os.path.relpath(dirpath, parent).split(os.sep)

        for file in sorted(files):
            name, ext = os.path.splitext(file)

            if ext == '.py':
                yield _import_module(
                    f'{".".join(path_parts)}.{name}',
                    os.path.join(dirpath, file),
                )


def compile_and_import(
    proto_files: Iterable[PathOrStr],
    includes: Iterable[PathOrStr] = (),
    output_dir: Optional[PathOrStr] = None,
) -> Iterator:
    """Compiles .proto files with protoc and yields the resulting modules.

    This is a generator: nothing is compiled until iteration begins.

    Args:
      proto_files: paths to .proto files to compile
      includes: include paths to use for .proto compilation
      output_dir: where to place the generated modules; a temporary directory is
          used if omitted (or otherwise falsy)

    Yields:
      the generated protobuf Python modules
    """
    if not output_dir:
        # The generated sources only need to exist while they are imported;
        # the temporary directory is removed when the generator finishes or is
        # closed.
        with tempfile.TemporaryDirectory(prefix='compiled_protos_') as tempdir:
            compile_protos(tempdir, proto_files, includes)
            yield from import_modules(tempdir)
        return

    compile_protos(output_dir, proto_files, includes)
    yield from import_modules(output_dir)


def compile_and_import_file(
    proto_file: PathOrStr,
    includes: Iterable[PathOrStr] = (),
    output_dir: Optional[PathOrStr] = None,
):
    """Compiles and imports the module for a single .proto file.

    Args:
      proto_file: path to the .proto file to compile
      includes: include paths to use for .proto compilation
      output_dir: where to place the generated module; a temporary directory
          is used if omitted. NOTE: every .py file in output_dir is imported,
          so if it already holds other modules the one returned may not be
          the module generated for proto_file.

    Returns:
      the first module yielded by compile_and_import

    Raises:
      StopIteration: if no Python module was produced
    """
    modules = compile_and_import([proto_file], includes, output_dir)
    try:
        return next(modules)
    finally:
        # Finish the generator now so that, when no output_dir was given, its
        # temporary directory is removed immediately rather than whenever the
        # suspended generator happens to be garbage collected.
        modules.close()  # type: ignore[attr-defined]


def compile_and_import_strings(
    contents: Iterable[str],
    includes: Iterable[PathOrStr] = (),
    output_dir: Optional[PathOrStr] = None,
) -> Iterator:
    """Compiles protos in one or more strings.

    Each distinct string is written to its own .proto file in a temporary
    directory, which is then compiled and imported with compile_and_import.

    Args:
      contents: .proto source text; a single str is treated as one proto
      includes: include paths to use for .proto compilation
      output_dir: where to place the generated modules; a temporary directory
          is used if omitted

    Yields:
      the generated protobuf Python modules
    """
    if isinstance(contents, str):
        contents = [contents]

    with tempfile.TemporaryDirectory(prefix='proto_sources_') as path:
        # Keyed by file path so identical strings are written and compiled
        # once, while preserving the order in which they were given.
        protos: Dict[Path, None] = {}

        for proto in contents:
            # Name each file after a digest of its contents so the same
            # contents always map to the same file name; the protobuf package
            # complains if it sees the same contents in files with different
            # names. hashlib is used instead of hash(): str hashes are
            # randomized per process and may be negative, so a reused
            # output_dir would accumulate differently-named copies of the same
            # generated module across runs, all of which get imported.
            digest = hashlib.sha256(proto.encode('utf-8')).hexdigest()[:16]
            proto_path = Path(path, f'protobuf_{digest}.proto')

            if proto_path not in protos:
                proto_path.write_text(proto, encoding='utf-8')
                protos[proto_path] = None

        yield from compile_and_import(list(protos), includes, output_dir)


T = TypeVar('T')


class _NestedPackage(Generic[T]):
    """Facilitates navigating protobuf packages as attributes."""

    def __init__(self, package: str):
        self._packages: Dict[str, _NestedPackage[T]] = {}
        self._items: List[T] = []
        self._package = package

    def _add_package(self, subpackage: str, package: '_NestedPackage') -> None:
        self._packages[subpackage] = package

    def _add_item(self, item) -> None:
        if item not in self._items:  # Don't store the same item multiple times.
            self._items.append(item)

    def __getattr__(self, attr: str):
        """Look up subpackages or package members."""
        if attr in self._packages:
            return self._packages[attr]

        for item in self._items:
            if hasattr(item, attr):
                return getattr(item, attr)

        raise AttributeError(
            f'Proto package "{self._package}" does not contain "{attr}"'
        )

    def __getitem__(self, subpackage: str) -> '_NestedPackage[T]':
        """Support accessing nested packages by name."""
        result = self

        for package in subpackage.split('.'):
            result = result._packages[package]

        return result

    def __dir__(self) -> List[str]:
        """List subpackages and members of modules as attributes."""
        attributes = list(self._packages)

        for item in self._items:
            for attr, value in vars(item).items():
                # Exclude private variables and modules from dir().
                if not attr.startswith('_') and not isinstance(
                    value, ModuleType
                ):
                    attributes.append(attr)

        return attributes

    def __iter__(self) -> Iterator['_NestedPackage[T]']:
        """Iterate over nested packages."""
        return iter(self._packages.values())

    def __repr__(self) -> str:
        msg = [f'ProtoPackage({self._package!r}']

        public_members = [
            i
            for i in vars(self)
            if i not in self._packages and not i.startswith('_')
        ]
        if public_members:
            msg.append(f'members={str(public_members)}')

        if self._packages:
            msg.append(f'subpackages={str(list(self._packages))}')

        return ', '.join(msg) + ')'

    def __str__(self) -> str:
        return self._package


class Packages(NamedTuple):
    """Items in a protobuf package structure; returned from as_packages.

    Fields:
      items_by_package: maps each full package name (e.g. 'foo.bar') to the
          list of items added under exactly that package, in insertion order
      packages: root of the attribute-navigable package tree
    """

    # Full package name -> items stored directly in that package.
    items_by_package: Dict[str, List]
    # Root node; subpackages hang off it by name component.
    packages: _NestedPackage


def as_packages(
    items: Iterable[Tuple[str, T]], packages: Optional[Packages] = None
) -> Packages:
    """Places items in a proto-style package structure navigable by attributes.

    Every dotted component of an item's package becomes a nested
    _NestedPackage (created on first use), and the item is attached to the
    innermost one. An empty package name produces a single subpackage named ''.

    Args:
      items: (package, item) tuples to insert into the package structure
      packages: if provided, update this Packages instead of creating a new one

    Returns:
      the updated (or newly created) Packages
    """
    result = Packages({}, _NestedPackage('')) if packages is None else packages

    # pylint: disable=protected-access
    for package_name, item in items:
        result.items_by_package.setdefault(package_name, []).append(item)

        node = result.packages
        parts = package_name.split('.')
        for depth in range(1, len(parts) + 1):
            name = parts[depth - 1]
            child = node._packages.get(name)
            if child is None:
                # Each node is named by its full dotted path, e.g. 'foo.bar'.
                child = _NestedPackage('.'.join(parts[:depth]))
                node._add_package(name, child)
            node = child

        node._add_item(item)
    # pylint: enable=protected-access

    return result


# Anything Library.from_paths accepts: a .proto path (str or Path) or an
# already-imported proto module.
PathOrModule = Union[str, Path, ModuleType]


class Library:
    """A collection of protocol buffer modules sorted by package.

    In Python, each .proto file is compiled into a Python module. The Library
    class makes it simple to navigate a collection of Python modules
    corresponding to .proto files, without relying on the location of these
    compiled modules.

    Proto messages and other types can be directly accessed by their protocol
    buffer package name. For example, the foo.bar.Baz message can be accessed
    in a Library called `protos` as:

      protos.packages.foo.bar.Baz

    A Library also provides the modules_by_package dictionary, for looking up
    the list of modules in a particular package, and the modules() generator
    for iterating over all modules.
    """

    @classmethod
    def from_paths(
        cls,
        protos: Iterable[PathOrModule],
        includes: Iterable[PathOrStr] = (),
        output_dir: Optional[PathOrStr] = None,
    ) -> 'Library':
        """Creates a Library from paths to proto files or proto modules.

        Args:
          protos: .proto file paths (str or Path) to compile, and/or
              already-imported proto modules to include as-is
          includes: include paths used when compiling the .proto files
          output_dir: where to place the generated modules; a temporary
              directory is used if omitted

        Returns:
          an instance of ``cls``, so subclasses construct their own type
        """
        paths: List[PathOrStr] = []
        modules: List[ModuleType] = []

        for proto in protos:
            if isinstance(proto, (Path, str)):
                paths.append(proto)
            else:
                modules.append(proto)

        if paths:
            modules += compile_and_import(paths, includes, output_dir)
        return cls(modules)

    @classmethod
    def from_strings(
        cls,
        contents: Iterable[str],
        includes: Iterable[PathOrStr] = (),
        output_dir: Optional[PathOrStr] = None,
    ) -> 'Library':
        """Creates a proto library from protos in the provided strings."""
        return cls(compile_and_import_strings(contents, includes, output_dir))

    def __init__(self, modules: Iterable[ModuleType]):
        """Constructs a Library from an iterable of modules.

        A Library can be constructed with modules dynamically compiled by
        compile_and_import. For example:

            protos = Library(compile_and_import(list_of_proto_files))
        """
        # Group modules by the proto package declared in their descriptor.
        self.modules_by_package, self.packages = as_packages(
            (m.DESCRIPTOR.package, m)  # type: ignore[attr-defined]
            for m in modules
        )

    def modules(self) -> Iterable:
        """Iterates over all protobuf modules in this library.

        Modules are yielded grouped by package, in the order each package was
        first added.
        """
        for module_list in self.modules_by_package.values():
            yield from module_list

    def messages(self) -> Iterable:
        """Iterates over all protobuf messages in this library.

        Nested message types are included; each message is yielded before the
        messages nested inside it.
        """
        for module in self.modules():
            yield from _nested_messages(
                module, module.DESCRIPTOR.message_types_by_name
            )


def _nested_messages(scope, message_names: Iterable[str]) -> Iterator:
    for name in message_names:
        msg = getattr(scope, name)
        yield msg
        yield from _nested_messages(msg, msg.DESCRIPTOR.nested_types_by_name)


def _repr_char(char: int) -> str:
    r"""Returns an ASCII char or the \x code for non-printable values."""
    if ord(' ') <= char <= ord('~'):
        return r"\'" if chr(char) == "'" else chr(char)

    return f'\\x{char:02X}'


def bytes_repr(value: bytes) -> str:
    r"""Formats bytes as a bytes literal, e.g. b'abc' or b'\x00\x01'.

    When at least half of the bytes are printable ASCII, each byte is rendered
    with _repr_char (characters where printable, \xNN escapes otherwise);
    when fewer are printable, every byte is written as an uppercase \xNN
    escape.
    """
    printable_count = sum(1 for byte in value if ord(' ') <= byte <= ord('~'))
    mostly_printable = 2 * printable_count >= len(value)

    pieces = (
        _repr_char(byte) if mostly_printable else f'\\x{byte:02X}'
        for byte in value
    )
    return "b'" + ''.join(pieces) + "'"


def _field_repr(field, value) -> str:
    """Formats a single field value as Python-like source text.

    Enum values become '<enum full name>.<value name>' (or their plain repr
    when the number is not a known enum value), messages are formatted
    recursively with proto_repr, bytes use bytes_repr, and everything else uses
    repr().
    """
    field_type = field.type

    if field_type == field.TYPE_ENUM:
        try:
            enum_value = field.enum_type.values_by_number[value]
        except KeyError:
            return repr(value)
        return f'{field.enum_type.full_name}.{enum_value.name}'

    if field_type == field.TYPE_MESSAGE:
        return proto_repr(value)

    return bytes_repr(value) if field_type == field.TYPE_BYTES else repr(value)


def _proto_repr(message) -> Iterator[str]:
    """Yields 'name=value' strings for the populated fields of ``message``.

    Fields are visited in message.DESCRIPTOR.fields order. A field is omitted
    when:
      - HasField() is supported for it and reports it unset;
      - HasField() raises ValueError for it, it is not repeated, and it equals
        its descriptor's default value;
      - it is repeated and empty.
    Map fields are rendered as {key: value, ...} and other repeated fields as
    [value, ...]; every element is formatted with _field_repr.
    """
    for field in message.DESCRIPTOR.fields:
        value = getattr(message, field.name)

        # Skip fields that are not present.
        try:
            if not message.HasField(field.name):
                continue
        except ValueError:
            # HasField() raises ValueError for fields without presence
            # tracking: repeated fields and, presumably, proto3 scalars
            # declared without `optional` (confirm for the protobuf runtime in
            # use). Non-repeated ones are skipped when they hold the default.
            if (
                field.label != field.LABEL_REPEATED
                and value == field.default_value
            ):
                continue

        if field.label == field.LABEL_REPEATED:
            if not value:
                continue

            # Map fields show up as repeated fields whose value is a Mapping;
            # their entry message type must have exactly two fields (key and
            # value), as the unpacking below requires.
            if isinstance(value, Mapping):
                key_desc, value_desc = field.message_type.fields
                values = ', '.join(
                    f'{_field_repr(key_desc, k)}: {_field_repr(value_desc, v)}'
                    for k, v in value.items()
                )
                yield f'{field.name}={{{values}}}'
            else:
                values = ', '.join(_field_repr(field, v) for v in value)
                yield f'{field.name}=[{values}]'
        else:
            yield f'{field.name}={_field_repr(field, value)}'


def proto_repr(message, *, wrap: bool = True) -> str:
    """Creates a repr-like string for a protobuf.

    In an interactive console that imports proto objects into the namespace, the
    output of proto_repr() can be used as Python source to create a proto
    object.

    Args:
      message: The protobuf message to format
      wrap: If true and black is available, the output is wrapped according to
          PEP8 using black. If black cannot parse the generated text (for
          example, a field named after a Python keyword such as `from` yields
          `from=...`), the unwrapped text is returned instead.
    """
    raw = f'{message.DESCRIPTOR.full_name}({", ".join(_proto_repr(message))})'

    if not wrap or black is None or black_mode is None:
        return raw

    try:
        return black.format_str(raw, mode=black_mode).strip()
    except ValueError:
        # black raises InvalidInput, a ValueError subclass, for text that is
        # not valid Python. A repr helper should degrade, not fail.
        _LOG.debug('black could not format %r; returning it unwrapped', raw)
        return raw