diff options
Diffstat (limited to 'generator/google/protobuf/message_factory.py')
-rw-r--r-- | generator/google/protobuf/message_factory.py | 106 |
1 files changed, 36 insertions, 70 deletions
diff --git a/generator/google/protobuf/message_factory.py b/generator/google/protobuf/message_factory.py index 1b059d1..36e2fef 100644 --- a/generator/google/protobuf/message_factory.py +++ b/generator/google/protobuf/message_factory.py @@ -1,6 +1,6 @@ # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# https://developers.google.com/protocol-buffers/ +# http://code.google.com/p/protobuf/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -28,17 +28,11 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -"""Provides a factory class for generating dynamic messages. - -The easiest way to use this class is if you have access to the FileDescriptor -protos containing the messages you want to create you can just do the following: - -message_classes = message_factory.GetMessages(iterable_of_file_descriptors) -my_proto_instance = message_classes['some.proto.package.MessageName']() -""" +"""Provides a factory class for generating dynamic messages.""" __author__ = 'matthewtoia@google.com (Matt Toia)' +from google.protobuf import descriptor_database from google.protobuf import descriptor_pool from google.protobuf import message from google.protobuf import reflection @@ -47,11 +41,8 @@ from google.protobuf import reflection class MessageFactory(object): """Factory for creating Proto2 messages from descriptors in a pool.""" - def __init__(self, pool=None): + def __init__(self): """Initializes a new factory.""" - self.pool = pool or descriptor_pool.DescriptorPool() - - # local cache of all classes built from protobuf descriptors self._classes = {} def GetPrototype(self, descriptor): @@ -66,68 +57,21 @@ class MessageFactory(object): Returns: A class describing the passed in descriptor. """ + if descriptor.full_name not in self._classes: - descriptor_name = descriptor.name - if str is bytes: # PY2 - descriptor_name = descriptor.name.encode('ascii', 'ignore') result_class = reflection.GeneratedProtocolMessageType( - descriptor_name, + descriptor.name.encode('ascii', 'ignore'), (message.Message,), - {'DESCRIPTOR': descriptor, '__module__': None}) - # If module not set, it wrongly points to the reflection.py module. + {'DESCRIPTOR': descriptor}) self._classes[descriptor.full_name] = result_class for field in descriptor.fields: if field.message_type: self.GetPrototype(field.message_type) - for extension in result_class.DESCRIPTOR.extensions: - if extension.containing_type.full_name not in self._classes: - self.GetPrototype(extension.containing_type) - extended_class = self._classes[extension.containing_type.full_name] - extended_class.RegisterExtension(extension) return self._classes[descriptor.full_name] - def GetMessages(self, files): - """Gets all the messages from a specified file. - - This will find and resolve dependencies, failing if the descriptor - pool cannot satisfy them. - - Args: - files: The file names to extract messages from. - - Returns: - A dictionary mapping proto names to the message classes. This will include - any dependent messages as well as any messages defined in the same file as - a specified message. - """ - result = {} - for file_name in files: - file_desc = self.pool.FindFileByName(file_name) - for name, msg in file_desc.message_types_by_name.items(): - if file_desc.package: - full_name = '.'.join([file_desc.package, name]) - else: - full_name = msg.name - result[full_name] = self.GetPrototype( - self.pool.FindMessageTypeByName(full_name)) - - # While the extension FieldDescriptors are created by the descriptor pool, - # the python classes created in the factory need them to be registered - # explicitly, which is done below. - # - # The call to RegisterExtension will specifically check if the - # extension was already registered on the object and either - # ignore the registration if the original was the same, or raise - # an error if they were different. - - for name, extension in file_desc.extensions_by_name.items(): - if extension.containing_type.full_name not in self._classes: - self.GetPrototype(extension.containing_type) - extended_class = self._classes[extension.containing_type.full_name] - extended_class.RegisterExtension(extension) - return result - +_DB = descriptor_database.DescriptorDatabase() +_POOL = descriptor_pool.DescriptorPool(_DB) _FACTORY = MessageFactory() @@ -138,10 +82,32 @@ def GetMessages(file_protos): file_protos: A sequence of file protos to build messages out of. Returns: - A dictionary mapping proto names to the message classes. This will include - any dependent messages as well as any messages defined in the same file as - a specified message. + A dictionary containing all the message types in the files mapping the + fully qualified name to a Message subclass for the descriptor. """ + + result = {} + for file_proto in file_protos: + _DB.Add(file_proto) for file_proto in file_protos: - _FACTORY.pool.Add(file_proto) - return _FACTORY.GetMessages([file_proto.name for file_proto in file_protos]) + for desc in _GetAllDescriptors(file_proto.message_type, file_proto.package): + result[desc.full_name] = _FACTORY.GetPrototype(desc) + return result + + +def _GetAllDescriptors(desc_protos, package): + """Gets all levels of nested message types as a flattened list of descriptors. + + Args: + desc_protos: The descriptor protos to process. + package: The package where the protos are defined. + + Yields: + Each message descriptor for each nested type. + """ + + for desc_proto in desc_protos: + name = '.'.join((package, desc_proto.name)) + yield _POOL.FindMessageTypeByName(name) + for nested_desc in _GetAllDescriptors(desc_proto.nested_type, name): + yield nested_desc |