diff options
Diffstat (limited to 'generator/nanopb_generator.py')
-rwxr-xr-x | generator/nanopb_generator.py | 221 |
1 files changed, 174 insertions, 47 deletions
diff --git a/generator/nanopb_generator.py b/generator/nanopb_generator.py index 5b74ca1..b4f1d83 100755 --- a/generator/nanopb_generator.py +++ b/generator/nanopb_generator.py @@ -3,17 +3,20 @@ from __future__ import unicode_literals '''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.''' -nanopb_version = "nanopb-0.3.9.1" +nanopb_version = "nanopb-0.3.9.8" import sys import re import codecs +import copy from functools import reduce try: # Add some dummy imports to keep packaging tools happy. import google, distutils.util # bbfreeze seems to need these import pkg_resources # pyinstaller / protobuf 2.5 seem to need these + import proto.nanopb_pb2 as nanopb_pb2 # pyinstaller seems to need this + import proto.plugin_pb2 as plugin_pb2 except: # Don't care, we will error out later if it is actually important. pass @@ -31,8 +34,7 @@ except: raise try: - import proto.nanopb_pb2 as nanopb_pb2 - import proto.plugin_pb2 as plugin_pb2 + from .proto import nanopb_pb2, plugin_pb2 except TypeError: sys.stderr.write(''' **************************************************************************** @@ -44,9 +46,16 @@ except TypeError: *** which protoc *** *** protoc --version *** *** python -c 'import google.protobuf; print(google.protobuf.__file__)' *** + *** If you are not able to find the python protobuf version using the *** + *** above command, use this command. *** + *** pip freeze | grep -i protobuf *** **************************************************************************** ''' + '\n') raise +except (ValueError, SystemError, ImportError): + # Probably invoked directly instead of via installed scripts. + import proto.nanopb_pb2 as nanopb_pb2 + import proto.plugin_pb2 as plugin_pb2 except: sys.stderr.write(''' ******************************************************************** @@ -95,11 +104,14 @@ try: except NameError: strtypes = (str, ) + class Names: '''Keeps a set of nested names and formats them to C identifier.''' def __init__(self, parts = ()): if isinstance(parts, Names): parts = parts.parts + elif isinstance(parts, strtypes): + parts = (parts,) self.parts = tuple(parts) def __str__(self): @@ -108,6 +120,8 @@ class Names: def __add__(self, other): if isinstance(other, strtypes): return Names(self.parts + (other,)) + elif isinstance(other, Names): + return Names(self.parts + other.parts) elif isinstance(other, tuple): return Names(self.parts + other) else: @@ -183,12 +197,15 @@ class Enum: '''desc is EnumDescriptorProto''' self.options = enum_options - self.names = names + desc.name + self.names = names + + # by definition, `names` include this enum's name + base_name = Names(names.parts[:-1]) if enum_options.long_names: - self.values = [(self.names + x.name, x.number) for x in desc.value] - else: self.values = [(names + x.name, x.number) for x in desc.value] + else: + self.values = [(base_name + x.name, x.number) for x in desc.value] self.value_longnames = [self.names + x.name for x in desc.value] self.packed = enum_options.packed_enum @@ -212,9 +229,12 @@ class Enum: result += ' %s;' % self.names - result += '\n#define _%s_MIN %s' % (self.names, self.values[0][0]) - result += '\n#define _%s_MAX %s' % (self.names, self.values[-1][0]) - result += '\n#define _%s_ARRAYSIZE ((%s)(%s+1))' % (self.names, self.names, self.values[-1][0]) + # sort the enum by value + sorted_values = sorted(self.values, key = lambda x: (x[1], x[0])) + + result += '\n#define _%s_MIN %s' % (self.names, sorted_values[0][0]) + result += '\n#define _%s_MAX %s' % (self.names, sorted_values[-1][0]) + result += '\n#define _%s_ARRAYSIZE ((%s)(%s+1))' % (self.names, self.names, sorted_values[-1][0]) if not self.options.long_names: # Define the long names always so that enum value references @@ -336,7 +356,7 @@ class Field: if field_options.type == nanopb_pb2.FT_STATIC and not can_be_static: raise Exception("Field '%s' is defined as static, but max_size or " "max_count is not given." % self.name) - + if field_options.fixed_count and self.max_count is None: raise Exception("Field '%s' is defined as fixed count, " "but max_count is not given." % self.name) @@ -606,7 +626,15 @@ class Field: '''Determine if this field needs 16bit or 32bit pb_field_t structure to compile properly. Returns numeric value or a C-expression for assert.''' check = [] - if self.pbtype == 'MESSAGE' and self.allocation == 'STATIC': + + need_check = False + + if self.pbtype == 'BYTES' and self.allocation == 'STATIC' and self.max_size > 251: + need_check = True + elif self.pbtype == 'MESSAGE' and self.allocation == 'STATIC': + need_check = True + + if need_check: if self.rules == 'REPEATED': check.append('pb_membersize(%s, %s[0])' % (self.struct_name, self.name)) elif self.rules == 'ONEOF': @@ -616,9 +644,6 @@ class Field: check.append('pb_membersize(%s, %s.%s)' % (self.struct_name, self.union_name, self.name)) else: check.append('pb_membersize(%s, %s)' % (self.struct_name, self.name)) - elif self.pbtype == 'BYTES' and self.allocation == 'STATIC': - if self.max_size > 251: - check.append('pb_membersize(%s, %s)' % (self.struct_name, self.name)) return FieldMaxSize([self.tag, self.max_size, self.max_count], check, @@ -640,6 +665,13 @@ class Field: if encsize is not None: # Include submessage length prefix encsize += varint_max_size(encsize.upperlimit()) + else: + my_msg = dependencies.get(str(self.struct_name)) + if my_msg and submsg.protofile == my_msg.protofile: + # The dependency is from the same file and size cannot be + # determined for it, thus we know it will not be possible + # in runtime either. + return None if encsize is None: # Submessage or its size cannot be found. @@ -719,8 +751,8 @@ class ExtensionRange(Field): return EncodedSize(0) class ExtensionField(Field): - def __init__(self, struct_name, desc, field_options): - self.fullname = struct_name + desc.name + def __init__(self, fullname, desc, field_options): + self.fullname = fullname self.extendee_name = names_from_type_name(desc.extendee) Field.__init__(self, self.fullname + 'struct', desc, field_options) @@ -845,24 +877,33 @@ class OneOf(Field): def encoded_size(self, dependencies): '''Returns the size of the largest oneof field.''' - largest = EncodedSize(0) - symbols = set() + largest = 0 + symbols = [] for f in self.fields: size = EncodedSize(f.encoded_size(dependencies)) - if size.value is None: + if size is None or size.value is None: return None elif size.symbols: - symbols.add(EncodedSize(f.submsgname + 'size').symbols[0]) - elif size.value > largest.value: - largest = size + symbols.append((f.tag, size.symbols[0])) + elif size.value > largest: + largest = size.value if not symbols: + # Simple case, all sizes were known at generator time return largest - symbols = list(symbols) - symbols.append(str(largest)) - max_size = lambda x, y: '({0} > {1} ? {0} : {1})'.format(x, y) - return reduce(max_size, symbols) + if largest > 0: + # Some sizes were known, some were not + symbols.insert(0, (0, largest)) + + if len(symbols) == 1: + # Only one symbol was needed + return EncodedSize(5, [symbols[0][1]]) + else: + # Use sizeof(union{}) construct to find the maximum size of + # submessages. + union_def = ' '.join('char f%d[%s];' % s for s in symbols) + return EncodedSize(5, ['sizeof(union{%s})' % union_def]) # --------------------------------------------------------------------------- # Generation of messages (structures) @@ -969,6 +1010,15 @@ class Message: result += default + '\n' return result + def all_fields(self): + '''Iterate over all fields in this message, including nested OneOfs.''' + for f in self.fields: + if isinstance(f, OneOf): + for f2 in f.fields: + yield f2 + else: + yield f + def count_required_fields(self): '''Returns number of required fields inside this message''' count = 0 @@ -1021,7 +1071,7 @@ class Message: # Processing of entire .proto files # --------------------------------------------------------------------------- -def iterate_messages(desc, names = Names()): +def iterate_messages(desc, flatten = False, names = Names()): '''Recursively find all messages. For each, yield name, DescriptorProto.''' if hasattr(desc, 'message_type'): submsgs = desc.message_type @@ -1030,19 +1080,22 @@ def iterate_messages(desc, names = Names()): for submsg in submsgs: sub_names = names + submsg.name - yield sub_names, submsg + if flatten: + yield Names(submsg.name), submsg + else: + yield sub_names, submsg - for x in iterate_messages(submsg, sub_names): + for x in iterate_messages(submsg, flatten, sub_names): yield x -def iterate_extensions(desc, names = Names()): +def iterate_extensions(desc, flatten = False, names = Names()): '''Recursively find all extensions. For each, yield name, FieldDescriptorProto. ''' for extension in desc.extension: yield names, extension - for subname, subdesc in iterate_messages(desc, names): + for subname, subdesc in iterate_messages(desc, flatten, names): for extension in subdesc.extension: yield subname, extension @@ -1104,43 +1157,79 @@ class ProtoFile: self.messages = [] self.extensions = [] + mangle_names = self.file_options.mangle_names + flatten = mangle_names == nanopb_pb2.M_FLATTEN + strip_prefix = None + if mangle_names == nanopb_pb2.M_STRIP_PACKAGE: + strip_prefix = "." + self.fdesc.package + + def create_name(names): + if mangle_names == nanopb_pb2.M_NONE: + return base_name + names + elif mangle_names == nanopb_pb2.M_STRIP_PACKAGE: + return Names(names) + else: + single_name = names + if isinstance(names, Names): + single_name = names.parts[-1] + return Names(single_name) + + def mangle_field_typename(typename): + if mangle_names == nanopb_pb2.M_FLATTEN: + return "." + typename.split(".")[-1] + elif strip_prefix is not None and typename.startswith(strip_prefix): + return typename[len(strip_prefix):] + else: + return typename + if self.fdesc.package: base_name = Names(self.fdesc.package.split('.')) else: base_name = Names() for enum in self.fdesc.enum_type: - enum_options = get_nanopb_suboptions(enum, self.file_options, base_name + enum.name) - self.enums.append(Enum(base_name, enum, enum_options)) + name = create_name(enum.name) + enum_options = get_nanopb_suboptions(enum, self.file_options, name) + self.enums.append(Enum(name, enum, enum_options)) - for names, message in iterate_messages(self.fdesc, base_name): - message_options = get_nanopb_suboptions(message, self.file_options, names) + for names, message in iterate_messages(self.fdesc, flatten): + name = create_name(names) + message_options = get_nanopb_suboptions(message, self.file_options, name) if message_options.skip_message: continue - self.messages.append(Message(names, message, message_options)) + message = copy.deepcopy(message) + for field in message.field: + if field.type in (FieldD.TYPE_MESSAGE, FieldD.TYPE_ENUM): + field.type_name = mangle_field_typename(field.type_name) + + self.messages.append(Message(name, message, message_options)) for enum in message.enum_type: - enum_options = get_nanopb_suboptions(enum, message_options, names + enum.name) - self.enums.append(Enum(names, enum, enum_options)) + name = create_name(names + enum.name) + enum_options = get_nanopb_suboptions(enum, message_options, name) + self.enums.append(Enum(name, enum, enum_options)) - for names, extension in iterate_extensions(self.fdesc, base_name): - field_options = get_nanopb_suboptions(extension, self.file_options, names + extension.name) + for names, extension in iterate_extensions(self.fdesc, flatten): + name = create_name(names + extension.name) + field_options = get_nanopb_suboptions(extension, self.file_options, name) if field_options.type != nanopb_pb2.FT_IGNORE: - self.extensions.append(ExtensionField(names, extension, field_options)) + self.extensions.append(ExtensionField(name, extension, field_options)) def add_dependency(self, other): for enum in other.enums: self.dependencies[str(enum.names)] = enum + enum.protofile = other for msg in other.messages: self.dependencies[str(msg.name)] = msg + msg.protofile = other # Fix field default values where enum short names are used. for enum in other.enums: if not enum.options.long_names: for message in self.messages: - for field in message.fields: + for field in message.all_fields(): if field.default in enum.value_longnames: idx = enum.value_longnames.index(field.default) field.default = enum.values[idx][0] @@ -1149,7 +1238,7 @@ class ProtoFile: for enum in other.enums: if not enum.has_negative(): for message in self.messages: - for field in message.fields: + for field in message.all_fields(): if field.pbtype == 'ENUM' and field.ctype == enum.names: field.pbtype = 'UENUM' @@ -1350,7 +1439,7 @@ class ProtoFile: msgs = '_'.join(str(n) for n in checks_msgnames) yield '/* If you get an error here, it means that you need to define PB_FIELD_32BIT\n' yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n' - yield ' * \n' + yield ' *\n' yield ' * The reason you need to do this is that some of your messages contain tag\n' yield ' * numbers or field sizes that are larger than what can fit in 8 or 16 bit\n' yield ' * field descriptors.\n' @@ -1367,7 +1456,7 @@ class ProtoFile: msgs = '_'.join(str(n) for n in checks_msgnames) yield '/* If you get an error here, it means that you need to define PB_FIELD_16BIT\n' yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n' - yield ' * \n' + yield ' *\n' yield ' * The reason you need to do this is that some of your messages contain tag\n' yield ' * numbers or field sizes that are larger than what can fit in the default\n' yield ' * 8 bit descriptors.\n' @@ -1378,7 +1467,7 @@ class ProtoFile: # Add check for sizeof(double) has_double = False for msg in self.messages: - for field in msg.fields: + for field in msg.all_fields(): if field.ctype == 'double': has_double = True @@ -1490,6 +1579,8 @@ optparser = OptionParser( usage = "Usage: nanopb_generator.py [options] file.pb ...", epilog = "Compile file.pb from file.proto by: 'protoc -ofile.pb file.proto'. " + "Output will be written to file.pb.h and file.pb.c.") +optparser.add_option("--version", dest="version", action="store_true", + help="Show version info and exit") optparser.add_option("-x", dest="exclude", metavar="FILE", action="append", default=[], help="Exclude file from generated #include list.") optparser.add_option("-e", "--extension", dest="extension", metavar="EXTENSION", default=".pb", @@ -1512,6 +1603,10 @@ optparser.add_option("-Q", "--generated-include-format", dest="genformat", optparser.add_option("-L", "--library-include-format", dest="libformat", metavar="FORMAT", default='#include <%s>\n', help="Set format string to use for including the nanopb pb.h header. [default: %default]") +optparser.add_option("--strip-path", dest="strip_path", action="store_true", default=True, + help="Strip directory path from #included .pb.h file name [default: %default]") +optparser.add_option("--no-strip-path", dest="strip_path", action="store_false", + help="Opposite of --strip-path") optparser.add_option("-T", "--no-timestamp", dest="notimestamp", action="store_true", default=False, help="Don't add timestamp to .pb.h and .pb.c preambles") optparser.add_option("-q", "--quiet", dest="quiet", action="store_true", default=False, @@ -1589,7 +1684,11 @@ def process_file(filename, fdesc, options, other_files = {}): noext = os.path.splitext(filename)[0] headername = noext + options.extension + options.header_extension sourcename = noext + options.extension + options.source_extension - headerbasename = os.path.basename(headername) + + if options.strip_path: + headerbasename = os.path.basename(headername) + else: + headerbasename = headername # List of .proto files that should not be included in the C header file # even if they are mentioned in the source .proto. @@ -1615,6 +1714,10 @@ def main_cli(): options, filenames = optparser.parse_args() + if options.version: + print(nanopb_version) + sys.exit(0) + if not filenames: optparser.print_help() sys.exit(1) @@ -1628,6 +1731,7 @@ def main_cli(): sys.exit(1) if options.verbose: + sys.stderr.write("Nanopb version %s\n" % nanopb_version) sys.stderr.write('Google Python protobuf library imported from %s, version %s\n' % (google.protobuf.__file__, google.protobuf.__version__)) @@ -1672,11 +1776,34 @@ def main_plugin(): import shlex args = shlex.split(params) + + if len(args) == 1 and ',' in args[0]: + # For compatibility with other protoc plugins, support options + # separated by comma. + lex = shlex.shlex(params) + lex.whitespace_split = True + lex.whitespace = ',' + args = list(lex) + + optparser.usage = "Usage: protoc --nanopb_out=[options][,more_options]:outdir file.proto" + optparser.epilog = "Output will be written to file.pb.h and file.pb.c." + + if '-h' in args or '--help' in args: + # By default optparser prints help to stdout, which doesn't work for + # protoc plugins. + optparser.print_help(sys.stderr) + sys.exit(1) + options, dummy = optparser.parse_args(args) + if options.version: + sys.stderr.write('%s\n' % (nanopb_version)) + sys.exit(0) + Globals.verbose_options = options.verbose if options.verbose: + sys.stderr.write("Nanopb version %s\n" % nanopb_version) sys.stderr.write('Google Python protobuf library imported from %s, version %s\n' % (google.protobuf.__file__, google.protobuf.__version__)) |