about summary refs log tree commit diff
path: root/generator/nanopb_generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'generator/nanopb_generator.py')
-rwxr-xr-x generator/nanopb_generator.py 221
1 file changed, 174 insertions, 47 deletions
diff --git a/generator/nanopb_generator.py b/generator/nanopb_generator.py
index 5b74ca1..b4f1d83 100755
--- a/generator/nanopb_generator.py
+++ b/generator/nanopb_generator.py
@@ -3,17 +3,20 @@
from __future__ import unicode_literals
'''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
-nanopb_version = "nanopb-0.3.9.1"
+nanopb_version = "nanopb-0.3.9.8"
import sys
import re
import codecs
+import copy
from functools import reduce
try:
# Add some dummy imports to keep packaging tools happy.
import google, distutils.util # bbfreeze seems to need these
import pkg_resources # pyinstaller / protobuf 2.5 seem to need these
+ import proto.nanopb_pb2 as nanopb_pb2 # pyinstaller seems to need this
+ import proto.plugin_pb2 as plugin_pb2
except:
# Don't care, we will error out later if it is actually important.
pass
@@ -31,8 +34,7 @@ except:
raise
try:
- import proto.nanopb_pb2 as nanopb_pb2
- import proto.plugin_pb2 as plugin_pb2
+ from .proto import nanopb_pb2, plugin_pb2
except TypeError:
sys.stderr.write('''
****************************************************************************
@@ -44,9 +46,16 @@ except TypeError:
*** which protoc ***
*** protoc --version ***
*** python -c 'import google.protobuf; print(google.protobuf.__file__)' ***
+ *** If you are not able to find the python protobuf version using the ***
+ *** above command, use this command. ***
+ *** pip freeze | grep -i protobuf ***
****************************************************************************
''' + '\n')
raise
+except (ValueError, SystemError, ImportError):
+ # Probably invoked directly instead of via installed scripts.
+ import proto.nanopb_pb2 as nanopb_pb2
+ import proto.plugin_pb2 as plugin_pb2
except:
sys.stderr.write('''
********************************************************************
@@ -95,11 +104,14 @@ try:
except NameError:
strtypes = (str, )
+
class Names:
'''Keeps a set of nested names and formats them to C identifier.'''
def __init__(self, parts = ()):
if isinstance(parts, Names):
parts = parts.parts
+ elif isinstance(parts, strtypes):
+ parts = (parts,)
self.parts = tuple(parts)
def __str__(self):
@@ -108,6 +120,8 @@ class Names:
def __add__(self, other):
if isinstance(other, strtypes):
return Names(self.parts + (other,))
+ elif isinstance(other, Names):
+ return Names(self.parts + other.parts)
elif isinstance(other, tuple):
return Names(self.parts + other)
else:
@@ -183,12 +197,15 @@ class Enum:
'''desc is EnumDescriptorProto'''
self.options = enum_options
- self.names = names + desc.name
+ self.names = names
+
+ # by definition, `names` include this enum's name
+ base_name = Names(names.parts[:-1])
if enum_options.long_names:
- self.values = [(self.names + x.name, x.number) for x in desc.value]
- else:
self.values = [(names + x.name, x.number) for x in desc.value]
+ else:
+ self.values = [(base_name + x.name, x.number) for x in desc.value]
self.value_longnames = [self.names + x.name for x in desc.value]
self.packed = enum_options.packed_enum
@@ -212,9 +229,12 @@ class Enum:
result += ' %s;' % self.names
- result += '\n#define _%s_MIN %s' % (self.names, self.values[0][0])
- result += '\n#define _%s_MAX %s' % (self.names, self.values[-1][0])
- result += '\n#define _%s_ARRAYSIZE ((%s)(%s+1))' % (self.names, self.names, self.values[-1][0])
+ # sort the enum by value
+ sorted_values = sorted(self.values, key = lambda x: (x[1], x[0]))
+
+ result += '\n#define _%s_MIN %s' % (self.names, sorted_values[0][0])
+ result += '\n#define _%s_MAX %s' % (self.names, sorted_values[-1][0])
+ result += '\n#define _%s_ARRAYSIZE ((%s)(%s+1))' % (self.names, self.names, sorted_values[-1][0])
if not self.options.long_names:
# Define the long names always so that enum value references
@@ -336,7 +356,7 @@ class Field:
if field_options.type == nanopb_pb2.FT_STATIC and not can_be_static:
raise Exception("Field '%s' is defined as static, but max_size or "
"max_count is not given." % self.name)
-
+
if field_options.fixed_count and self.max_count is None:
raise Exception("Field '%s' is defined as fixed count, "
"but max_count is not given." % self.name)
@@ -606,7 +626,15 @@ class Field:
'''Determine if this field needs 16bit or 32bit pb_field_t structure to compile properly.
Returns numeric value or a C-expression for assert.'''
check = []
- if self.pbtype == 'MESSAGE' and self.allocation == 'STATIC':
+
+ need_check = False
+
+ if self.pbtype == 'BYTES' and self.allocation == 'STATIC' and self.max_size > 251:
+ need_check = True
+ elif self.pbtype == 'MESSAGE' and self.allocation == 'STATIC':
+ need_check = True
+
+ if need_check:
if self.rules == 'REPEATED':
check.append('pb_membersize(%s, %s[0])' % (self.struct_name, self.name))
elif self.rules == 'ONEOF':
@@ -616,9 +644,6 @@ class Field:
check.append('pb_membersize(%s, %s.%s)' % (self.struct_name, self.union_name, self.name))
else:
check.append('pb_membersize(%s, %s)' % (self.struct_name, self.name))
- elif self.pbtype == 'BYTES' and self.allocation == 'STATIC':
- if self.max_size > 251:
- check.append('pb_membersize(%s, %s)' % (self.struct_name, self.name))
return FieldMaxSize([self.tag, self.max_size, self.max_count],
check,
@@ -640,6 +665,13 @@ class Field:
if encsize is not None:
# Include submessage length prefix
encsize += varint_max_size(encsize.upperlimit())
+ else:
+ my_msg = dependencies.get(str(self.struct_name))
+ if my_msg and submsg.protofile == my_msg.protofile:
+ # The dependency is from the same file and size cannot be
+ # determined for it, thus we know it will not be possible
+ # in runtime either.
+ return None
if encsize is None:
# Submessage or its size cannot be found.
@@ -719,8 +751,8 @@ class ExtensionRange(Field):
return EncodedSize(0)
class ExtensionField(Field):
- def __init__(self, struct_name, desc, field_options):
- self.fullname = struct_name + desc.name
+ def __init__(self, fullname, desc, field_options):
+ self.fullname = fullname
self.extendee_name = names_from_type_name(desc.extendee)
Field.__init__(self, self.fullname + 'struct', desc, field_options)
@@ -845,24 +877,33 @@ class OneOf(Field):
def encoded_size(self, dependencies):
'''Returns the size of the largest oneof field.'''
- largest = EncodedSize(0)
- symbols = set()
+ largest = 0
+ symbols = []
for f in self.fields:
size = EncodedSize(f.encoded_size(dependencies))
- if size.value is None:
+ if size is None or size.value is None:
return None
elif size.symbols:
- symbols.add(EncodedSize(f.submsgname + 'size').symbols[0])
- elif size.value > largest.value:
- largest = size
+ symbols.append((f.tag, size.symbols[0]))
+ elif size.value > largest:
+ largest = size.value
if not symbols:
+ # Simple case, all sizes were known at generator time
return largest
- symbols = list(symbols)
- symbols.append(str(largest))
- max_size = lambda x, y: '({0} > {1} ? {0} : {1})'.format(x, y)
- return reduce(max_size, symbols)
+ if largest > 0:
+ # Some sizes were known, some were not
+ symbols.insert(0, (0, largest))
+
+ if len(symbols) == 1:
+ # Only one symbol was needed
+ return EncodedSize(5, [symbols[0][1]])
+ else:
+ # Use sizeof(union{}) construct to find the maximum size of
+ # submessages.
+ union_def = ' '.join('char f%d[%s];' % s for s in symbols)
+ return EncodedSize(5, ['sizeof(union{%s})' % union_def])
# ---------------------------------------------------------------------------
# Generation of messages (structures)
@@ -969,6 +1010,15 @@ class Message:
result += default + '\n'
return result
+ def all_fields(self):
+ '''Iterate over all fields in this message, including nested OneOfs.'''
+ for f in self.fields:
+ if isinstance(f, OneOf):
+ for f2 in f.fields:
+ yield f2
+ else:
+ yield f
+
def count_required_fields(self):
'''Returns number of required fields inside this message'''
count = 0
@@ -1021,7 +1071,7 @@ class Message:
# Processing of entire .proto files
# ---------------------------------------------------------------------------
-def iterate_messages(desc, names = Names()):
+def iterate_messages(desc, flatten = False, names = Names()):
'''Recursively find all messages. For each, yield name, DescriptorProto.'''
if hasattr(desc, 'message_type'):
submsgs = desc.message_type
@@ -1030,19 +1080,22 @@ def iterate_messages(desc, names = Names()):
for submsg in submsgs:
sub_names = names + submsg.name
- yield sub_names, submsg
+ if flatten:
+ yield Names(submsg.name), submsg
+ else:
+ yield sub_names, submsg
- for x in iterate_messages(submsg, sub_names):
+ for x in iterate_messages(submsg, flatten, sub_names):
yield x
-def iterate_extensions(desc, names = Names()):
+def iterate_extensions(desc, flatten = False, names = Names()):
'''Recursively find all extensions.
For each, yield name, FieldDescriptorProto.
'''
for extension in desc.extension:
yield names, extension
- for subname, subdesc in iterate_messages(desc, names):
+ for subname, subdesc in iterate_messages(desc, flatten, names):
for extension in subdesc.extension:
yield subname, extension
@@ -1104,43 +1157,79 @@ class ProtoFile:
self.messages = []
self.extensions = []
+ mangle_names = self.file_options.mangle_names
+ flatten = mangle_names == nanopb_pb2.M_FLATTEN
+ strip_prefix = None
+ if mangle_names == nanopb_pb2.M_STRIP_PACKAGE:
+ strip_prefix = "." + self.fdesc.package
+
+ def create_name(names):
+ if mangle_names == nanopb_pb2.M_NONE:
+ return base_name + names
+ elif mangle_names == nanopb_pb2.M_STRIP_PACKAGE:
+ return Names(names)
+ else:
+ single_name = names
+ if isinstance(names, Names):
+ single_name = names.parts[-1]
+ return Names(single_name)
+
+ def mangle_field_typename(typename):
+ if mangle_names == nanopb_pb2.M_FLATTEN:
+ return "." + typename.split(".")[-1]
+ elif strip_prefix is not None and typename.startswith(strip_prefix):
+ return typename[len(strip_prefix):]
+ else:
+ return typename
+
if self.fdesc.package:
base_name = Names(self.fdesc.package.split('.'))
else:
base_name = Names()
for enum in self.fdesc.enum_type:
- enum_options = get_nanopb_suboptions(enum, self.file_options, base_name + enum.name)
- self.enums.append(Enum(base_name, enum, enum_options))
+ name = create_name(enum.name)
+ enum_options = get_nanopb_suboptions(enum, self.file_options, name)
+ self.enums.append(Enum(name, enum, enum_options))
- for names, message in iterate_messages(self.fdesc, base_name):
- message_options = get_nanopb_suboptions(message, self.file_options, names)
+ for names, message in iterate_messages(self.fdesc, flatten):
+ name = create_name(names)
+ message_options = get_nanopb_suboptions(message, self.file_options, name)
if message_options.skip_message:
continue
- self.messages.append(Message(names, message, message_options))
+ message = copy.deepcopy(message)
+ for field in message.field:
+ if field.type in (FieldD.TYPE_MESSAGE, FieldD.TYPE_ENUM):
+ field.type_name = mangle_field_typename(field.type_name)
+
+ self.messages.append(Message(name, message, message_options))
for enum in message.enum_type:
- enum_options = get_nanopb_suboptions(enum, message_options, names + enum.name)
- self.enums.append(Enum(names, enum, enum_options))
+ name = create_name(names + enum.name)
+ enum_options = get_nanopb_suboptions(enum, message_options, name)
+ self.enums.append(Enum(name, enum, enum_options))
- for names, extension in iterate_extensions(self.fdesc, base_name):
- field_options = get_nanopb_suboptions(extension, self.file_options, names + extension.name)
+ for names, extension in iterate_extensions(self.fdesc, flatten):
+ name = create_name(names + extension.name)
+ field_options = get_nanopb_suboptions(extension, self.file_options, name)
if field_options.type != nanopb_pb2.FT_IGNORE:
- self.extensions.append(ExtensionField(names, extension, field_options))
+ self.extensions.append(ExtensionField(name, extension, field_options))
def add_dependency(self, other):
for enum in other.enums:
self.dependencies[str(enum.names)] = enum
+ enum.protofile = other
for msg in other.messages:
self.dependencies[str(msg.name)] = msg
+ msg.protofile = other
# Fix field default values where enum short names are used.
for enum in other.enums:
if not enum.options.long_names:
for message in self.messages:
- for field in message.fields:
+ for field in message.all_fields():
if field.default in enum.value_longnames:
idx = enum.value_longnames.index(field.default)
field.default = enum.values[idx][0]
@@ -1149,7 +1238,7 @@ class ProtoFile:
for enum in other.enums:
if not enum.has_negative():
for message in self.messages:
- for field in message.fields:
+ for field in message.all_fields():
if field.pbtype == 'ENUM' and field.ctype == enum.names:
field.pbtype = 'UENUM'
@@ -1350,7 +1439,7 @@ class ProtoFile:
msgs = '_'.join(str(n) for n in checks_msgnames)
yield '/* If you get an error here, it means that you need to define PB_FIELD_32BIT\n'
yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n'
- yield ' * \n'
+ yield ' *\n'
yield ' * The reason you need to do this is that some of your messages contain tag\n'
yield ' * numbers or field sizes that are larger than what can fit in 8 or 16 bit\n'
yield ' * field descriptors.\n'
@@ -1367,7 +1456,7 @@ class ProtoFile:
msgs = '_'.join(str(n) for n in checks_msgnames)
yield '/* If you get an error here, it means that you need to define PB_FIELD_16BIT\n'
yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n'
- yield ' * \n'
+ yield ' *\n'
yield ' * The reason you need to do this is that some of your messages contain tag\n'
yield ' * numbers or field sizes that are larger than what can fit in the default\n'
yield ' * 8 bit descriptors.\n'
@@ -1378,7 +1467,7 @@ class ProtoFile:
# Add check for sizeof(double)
has_double = False
for msg in self.messages:
- for field in msg.fields:
+ for field in msg.all_fields():
if field.ctype == 'double':
has_double = True
@@ -1490,6 +1579,8 @@ optparser = OptionParser(
usage = "Usage: nanopb_generator.py [options] file.pb ...",
epilog = "Compile file.pb from file.proto by: 'protoc -ofile.pb file.proto'. " +
"Output will be written to file.pb.h and file.pb.c.")
+optparser.add_option("--version", dest="version", action="store_true",
+ help="Show version info and exit")
optparser.add_option("-x", dest="exclude", metavar="FILE", action="append", default=[],
help="Exclude file from generated #include list.")
optparser.add_option("-e", "--extension", dest="extension", metavar="EXTENSION", default=".pb",
@@ -1512,6 +1603,10 @@ optparser.add_option("-Q", "--generated-include-format", dest="genformat",
optparser.add_option("-L", "--library-include-format", dest="libformat",
metavar="FORMAT", default='#include <%s>\n',
help="Set format string to use for including the nanopb pb.h header. [default: %default]")
+optparser.add_option("--strip-path", dest="strip_path", action="store_true", default=True,
+ help="Strip directory path from #included .pb.h file name [default: %default]")
+optparser.add_option("--no-strip-path", dest="strip_path", action="store_false",
+ help="Opposite of --strip-path")
optparser.add_option("-T", "--no-timestamp", dest="notimestamp", action="store_true", default=False,
help="Don't add timestamp to .pb.h and .pb.c preambles")
optparser.add_option("-q", "--quiet", dest="quiet", action="store_true", default=False,
@@ -1589,7 +1684,11 @@ def process_file(filename, fdesc, options, other_files = {}):
noext = os.path.splitext(filename)[0]
headername = noext + options.extension + options.header_extension
sourcename = noext + options.extension + options.source_extension
- headerbasename = os.path.basename(headername)
+
+ if options.strip_path:
+ headerbasename = os.path.basename(headername)
+ else:
+ headerbasename = headername
# List of .proto files that should not be included in the C header file
# even if they are mentioned in the source .proto.
@@ -1615,6 +1714,10 @@ def main_cli():
options, filenames = optparser.parse_args()
+ if options.version:
+ print(nanopb_version)
+ sys.exit(0)
+
if not filenames:
optparser.print_help()
sys.exit(1)
@@ -1628,6 +1731,7 @@ def main_cli():
sys.exit(1)
if options.verbose:
+ sys.stderr.write("Nanopb version %s\n" % nanopb_version)
sys.stderr.write('Google Python protobuf library imported from %s, version %s\n'
% (google.protobuf.__file__, google.protobuf.__version__))
@@ -1672,11 +1776,34 @@ def main_plugin():
import shlex
args = shlex.split(params)
+
+ if len(args) == 1 and ',' in args[0]:
+ # For compatibility with other protoc plugins, support options
+ # separated by comma.
+ lex = shlex.shlex(params)
+ lex.whitespace_split = True
+ lex.whitespace = ','
+ args = list(lex)
+
+ optparser.usage = "Usage: protoc --nanopb_out=[options][,more_options]:outdir file.proto"
+ optparser.epilog = "Output will be written to file.pb.h and file.pb.c."
+
+ if '-h' in args or '--help' in args:
+ # By default optparser prints help to stdout, which doesn't work for
+ # protoc plugins.
+ optparser.print_help(sys.stderr)
+ sys.exit(1)
+
options, dummy = optparser.parse_args(args)
+ if options.version:
+ sys.stderr.write('%s\n' % (nanopb_version))
+ sys.exit(0)
+
Globals.verbose_options = options.verbose
if options.verbose:
+ sys.stderr.write("Nanopb version %s\n" % nanopb_version)
sys.stderr.write('Google Python protobuf library imported from %s, version %s\n'
% (google.protobuf.__file__, google.protobuf.__version__))