diff options
Diffstat (limited to 'generator/google/protobuf/internal/python_message.py')
-rw-r--r-- | generator/google/protobuf/internal/python_message.py | 628 |
1 files changed, 114 insertions, 514 deletions
diff --git a/generator/google/protobuf/internal/python_message.py b/generator/google/protobuf/internal/python_message.py index c0d0ad4..4bea57a 100644 --- a/generator/google/protobuf/internal/python_message.py +++ b/generator/google/protobuf/internal/python_message.py @@ -1,6 +1,6 @@ # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# https://developers.google.com/protocol-buffers/ +# http://code.google.com/p/protobuf/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -50,20 +50,13 @@ this file*. __author__ = 'robinson@google.com (Will Robinson)' -from io import BytesIO -import sys -import struct -import weakref - -import six try: - import six.moves.copyreg as copyreg + from cStringIO import StringIO except ImportError: - # On some platforms, for example gMac, we run native Python because there is - # nothing like hermetic Python. This means lesser control on the system and - # the six.moves package may be missing (is missing on 20150321 on gMac). Be - # extra conservative and try to load the old replacement if it fails. - import copy_reg as copyreg + from StringIO import StringIO +import copy_reg +import struct +import weakref # We use "as" to avoid name collisions with variables. from google.protobuf.internal import containers @@ -72,116 +65,41 @@ from google.protobuf.internal import encoder from google.protobuf.internal import enum_type_wrapper from google.protobuf.internal import message_listener as message_listener_mod from google.protobuf.internal import type_checkers -from google.protobuf.internal import well_known_types from google.protobuf.internal import wire_format from google.protobuf import descriptor as descriptor_mod from google.protobuf import message as message_mod from google.protobuf import text_format _FieldDescriptor = descriptor_mod.FieldDescriptor -_AnyFullTypeName = 'google.protobuf.Any' -class GeneratedProtocolMessageType(type): +def NewMessage(bases, descriptor, dictionary): + _AddClassAttributesForNestedExtensions(descriptor, dictionary) + _AddSlots(descriptor, dictionary) + return bases - """Metaclass for protocol message classes created at runtime from Descriptors. - - We add implementations for all methods described in the Message class. We - also create properties to allow getting/setting all fields in the protocol - message. Finally, we create slots to prevent users from accidentally - "setting" nonexistent fields in the protocol message, which then wouldn't get - serialized / deserialized properly. - - The protocol compiler currently uses this metaclass to create protocol - message classes at runtime. Clients can also manually create their own - classes at runtime, as in this example: - - mydescriptor = Descriptor(.....) - factory = symbol_database.Default() - factory.pool.AddDescriptor(mydescriptor) - MyProtoClass = factory.GetPrototype(mydescriptor) - myproto_instance = MyProtoClass() - myproto.foo_field = 23 - ... - """ - # Must be consistent with the protocol-compiler code in - # proto2/compiler/internal/generator.*. - _DESCRIPTOR_KEY = 'DESCRIPTOR' +def InitMessage(descriptor, cls): + cls._decoders_by_tag = {} + cls._extensions_by_name = {} + cls._extensions_by_number = {} + if (descriptor.has_options and + descriptor.GetOptions().message_set_wire_format): + cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( + decoder.MessageSetItemDecoder(cls._extensions_by_number)) - def __new__(cls, name, bases, dictionary): - """Custom allocation for runtime-generated class types. - - We override __new__ because this is apparently the only place - where we can meaningfully set __slots__ on the class we're creating(?). - (The interplay between metaclasses and slots is not very well-documented). - - Args: - name: Name of the class (ignored, but required by the - metaclass protocol). - bases: Base classes of the class we're constructing. - (Should be message.Message). We ignore this field, but - it's required by the metaclass protocol - dictionary: The class dictionary of the class we're - constructing. dictionary[_DESCRIPTOR_KEY] must contain - a Descriptor object describing this protocol message - type. - - Returns: - Newly-allocated class. - """ - descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] - if descriptor.full_name in well_known_types.WKTBASES: - bases += (well_known_types.WKTBASES[descriptor.full_name],) - _AddClassAttributesForNestedExtensions(descriptor, dictionary) - _AddSlots(descriptor, dictionary) - - superclass = super(GeneratedProtocolMessageType, cls) - new_class = superclass.__new__(cls, name, bases, dictionary) - return new_class - - def __init__(cls, name, bases, dictionary): - """Here we perform the majority of our work on the class. - We add enum getters, an __init__ method, implementations - of all Message methods, and properties for all fields - in the protocol type. + # Attach stuff to each FieldDescriptor for quick lookup later on. + for field in descriptor.fields: + _AttachFieldHelpers(cls, field) - Args: - name: Name of the class (ignored, but required by the - metaclass protocol). - bases: Base classes of the class we're constructing. - (Should be message.Message). We ignore this field, but - it's required by the metaclass protocol - dictionary: The class dictionary of the class we're - constructing. dictionary[_DESCRIPTOR_KEY] must contain - a Descriptor object describing this protocol message - type. - """ - descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] - cls._decoders_by_tag = {} - cls._extensions_by_name = {} - cls._extensions_by_number = {} - if (descriptor.has_options and - descriptor.GetOptions().message_set_wire_format): - cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( - decoder.MessageSetItemDecoder(cls._extensions_by_number), None) - - # Attach stuff to each FieldDescriptor for quick lookup later on. - for field in descriptor.fields: - _AttachFieldHelpers(cls, field) - - descriptor._concrete_class = cls # pylint: disable=protected-access - _AddEnumValues(descriptor, cls) - _AddInitMethod(descriptor, cls) - _AddPropertiesForFields(descriptor, cls) - _AddPropertiesForExtensions(descriptor, cls) - _AddStaticMethods(cls) - _AddMessageMethods(descriptor, cls) - _AddPrivateHelperMethods(descriptor, cls) - copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) - - superclass = super(GeneratedProtocolMessageType, cls) - superclass.__init__(name, bases, dictionary) + _AddEnumValues(descriptor, cls) + _AddInitMethod(descriptor, cls) + _AddPropertiesForFields(descriptor, cls) + _AddPropertiesForExtensions(descriptor, cls) + _AddStaticMethods(cls) + _AddMessageMethods(descriptor, cls) + _AddPrivateHelperMethods(cls) + copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) # Stateless helpers for GeneratedProtocolMessageType below. @@ -258,8 +176,7 @@ def _AddSlots(message_descriptor, dictionary): '_is_present_in_parent', '_listener', '_listener_for_children', - '__weakref__', - '_oneofs'] + '__weakref__'] def _IsMessageSetExtension(field): @@ -267,40 +184,16 @@ def _IsMessageSetExtension(field): field.containing_type.has_options and field.containing_type.GetOptions().message_set_wire_format and field.type == _FieldDescriptor.TYPE_MESSAGE and + field.message_type == field.extension_scope and field.label == _FieldDescriptor.LABEL_OPTIONAL) -def _IsMapField(field): - return (field.type == _FieldDescriptor.TYPE_MESSAGE and - field.message_type.has_options and - field.message_type.GetOptions().map_entry) - - -def _IsMessageMapField(field): - value_type = field.message_type.fields_by_name["value"] - return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE - - def _AttachFieldHelpers(cls, field_descriptor): is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) - is_packable = (is_repeated and - wire_format.IsTypePackable(field_descriptor.type)) - if not is_packable: - is_packed = False - elif field_descriptor.containing_type.syntax == "proto2": - is_packed = (field_descriptor.has_options and - field_descriptor.GetOptions().packed) - else: - has_packed_false = (field_descriptor.has_options and - field_descriptor.GetOptions().HasField("packed") and - field_descriptor.GetOptions().packed == False) - is_packed = not has_packed_false - is_map_entry = _IsMapField(field_descriptor) - - if is_map_entry: - field_encoder = encoder.MapEncoder(field_descriptor) - sizer = encoder.MapSizer(field_descriptor) - elif _IsMessageSetExtension(field_descriptor): + is_packed = (field_descriptor.has_options and + field_descriptor.GetOptions().packed) + + if _IsMessageSetExtension(field_descriptor): field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) sizer = encoder.MessageSetItemSizer(field_descriptor.number) else: @@ -316,27 +209,10 @@ def _AttachFieldHelpers(cls, field_descriptor): def AddDecoder(wiretype, is_packed): tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) - decode_type = field_descriptor.type - if (decode_type == _FieldDescriptor.TYPE_ENUM and - type_checkers.SupportsOpenEnums(field_descriptor)): - decode_type = _FieldDescriptor.TYPE_INT32 - - oneof_descriptor = None - if field_descriptor.containing_oneof is not None: - oneof_descriptor = field_descriptor - - if is_map_entry: - is_message_map = _IsMessageMapField(field_descriptor) - - field_decoder = decoder.MapDecoder( - field_descriptor, _GetInitializeDefaultForMap(field_descriptor), - is_message_map) - else: - field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( - field_descriptor.number, is_repeated, is_packed, - field_descriptor, field_descriptor._default_constructor) - - cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) + cls._decoders_by_tag[tag_bytes] = ( + type_checkers.TYPE_TO_DECODER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor)) AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], False) @@ -349,7 +225,7 @@ def _AttachFieldHelpers(cls, field_descriptor): def _AddClassAttributesForNestedExtensions(descriptor, dictionary): extension_dict = descriptor.extensions_by_name - for extension_name, extension_field in extension_dict.items(): + for extension_name, extension_field in extension_dict.iteritems(): assert extension_name not in dictionary dictionary[extension_name] = extension_field @@ -369,26 +245,6 @@ def _AddEnumValues(descriptor, cls): setattr(cls, enum_value.name, enum_value.number) -def _GetInitializeDefaultForMap(field): - if field.label != _FieldDescriptor.LABEL_REPEATED: - raise ValueError('map_entry set on non-repeated field %s' % ( - field.name)) - fields_by_name = field.message_type.fields_by_name - key_checker = type_checkers.GetTypeChecker(fields_by_name['key']) - - value_field = fields_by_name['value'] - if _IsMessageMapField(field): - def MakeMessageMapDefault(message): - return containers.MessageMap( - message._listener_for_children, value_field.message_type, key_checker) - return MakeMessageMapDefault - else: - value_checker = type_checkers.GetTypeChecker(value_field) - def MakePrimitiveMapDefault(message): - return containers.ScalarMap( - message._listener_for_children, key_checker, value_checker) - return MakePrimitiveMapDefault - def _DefaultValueConstructorForField(field): """Returns a function which returns a default value for a field. @@ -403,9 +259,6 @@ def _DefaultValueConstructorForField(field): value may refer back to |message| via a weak reference. """ - if _IsMapField(field): - return _GetInitializeDefaultForMap(field) - if field.label == _FieldDescriptor.LABEL_REPEATED: if field.has_default_value and field.default_value != []: raise ValueError('Repeated field default value not empty list: %s' % ( @@ -419,7 +272,7 @@ def _DefaultValueConstructorForField(field): message._listener_for_children, field.message_type) return MakeRepeatedMessageDefault else: - type_checker = type_checkers.GetTypeChecker(field) + type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) def MakeRepeatedScalarDefault(message): return containers.RepeatedScalarFieldContainer( message._listener_for_children, type_checker) @@ -430,10 +283,7 @@ def _DefaultValueConstructorForField(field): message_type = field.message_type def MakeSubMessageDefault(message): result = message_type._concrete_class() - result._SetListener( - _OneofListener(message, field) - if field.containing_oneof is not None - else message._listener_for_children) + result._SetListener(message._listener_for_children) return result return MakeSubMessageDefault @@ -444,95 +294,38 @@ def _DefaultValueConstructorForField(field): return MakeScalarDefault -def _ReraiseTypeErrorWithFieldName(message_name, field_name): - """Re-raise the currently-handled TypeError with the field name added.""" - exc = sys.exc_info()[1] - if len(exc.args) == 1 and type(exc) is TypeError: - # simple TypeError; add field name to exception message - exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name)) - - # re-raise possibly-amended exception with original traceback: - six.reraise(type(exc), exc, sys.exc_info()[2]) - - def _AddInitMethod(message_descriptor, cls): """Adds an __init__ method to cls.""" - - def _GetIntegerEnumValue(enum_type, value): - """Convert a string or integer enum value to an integer. - - If the value is a string, it is converted to the enum value in - enum_type with the same name. If the value is not a string, it's - returned as-is. (No conversion or bounds-checking is done.) - """ - if isinstance(value, six.string_types): - try: - return enum_type.values_by_name[value].number - except KeyError: - raise ValueError('Enum type %s: unknown label "%s"' % ( - enum_type.full_name, value)) - return value - + fields = message_descriptor.fields def init(self, **kwargs): self._cached_byte_size = 0 self._cached_byte_size_dirty = len(kwargs) > 0 self._fields = {} - # Contains a mapping from oneof field descriptors to the descriptor - # of the currently set field in that oneof field. - self._oneofs = {} - # _unknown_fields is () when empty for efficiency, and will be turned into # a list if fields are added. self._unknown_fields = () self._is_present_in_parent = False self._listener = message_listener_mod.NullMessageListener() self._listener_for_children = _Listener(self) - for field_name, field_value in kwargs.items(): + for field_name, field_value in kwargs.iteritems(): field = _GetFieldByName(message_descriptor, field_name) if field is None: raise TypeError("%s() got an unexpected keyword argument '%s'" % (message_descriptor.name, field_name)) - if field_value is None: - # field=None is the same as no field at all. - continue if field.label == _FieldDescriptor.LABEL_REPEATED: copy = field._default_constructor(self) if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite - if _IsMapField(field): - if _IsMessageMapField(field): - for key in field_value: - copy[key].MergeFrom(field_value[key]) - else: - copy.update(field_value) - else: - for val in field_value: - if isinstance(val, dict): - copy.add(**val) - else: - copy.add().MergeFrom(val) + for val in field_value: + copy.add().MergeFrom(val) else: # Scalar - if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: - field_value = [_GetIntegerEnumValue(field.enum_type, val) - for val in field_value] copy.extend(field_value) self._fields[field] = copy elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: copy = field._default_constructor(self) - new_val = field_value - if isinstance(field_value, dict): - new_val = field.message_type._concrete_class(**field_value) - try: - copy.MergeFrom(new_val) - except TypeError: - _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) + copy.MergeFrom(field_value) self._fields[field] = copy else: - if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: - field_value = _GetIntegerEnumValue(field.enum_type, field_value) - try: - setattr(self, field_name, field_value) - except TypeError: - _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) + setattr(self, field_name, field_value) init.__module__ = None init.__doc__ = None @@ -551,8 +344,7 @@ def _GetFieldByName(message_descriptor, field_name): try: return message_descriptor.fields_by_name[field_name] except KeyError: - raise ValueError('Protocol message %s has no "%s" field.' % - (message_descriptor.name, field_name)) + raise ValueError('Protocol message has no "%s" field.' % field_name) def _AddPropertiesForFields(descriptor, cls): @@ -648,10 +440,9 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): """ proto_field_name = field.name property_name = _PropertyName(proto_field_name) - type_checker = type_checkers.GetTypeChecker(field) + type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) default_value = field.default_value valid_values = set() - is_proto3 = field.containing_type.syntax == "proto3" def getter(self): # TODO(protobuf-team): This may be broken since there may not be @@ -659,30 +450,14 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): return self._fields.get(field, default_value) getter.__module__ = None getter.__doc__ = 'Getter for %s.' % proto_field_name - - clear_when_set_to_default = is_proto3 and not field.containing_oneof - - def field_setter(self, new_value): - # pylint: disable=protected-access - # Testing the value for truthiness captures all of the proto3 defaults - # (0, 0.0, enum 0, and False). - new_value = type_checker.CheckValue(new_value) - if clear_when_set_to_default and not new_value: - self._fields.pop(field, None) - else: - self._fields[field] = new_value + def setter(self, new_value): + type_checker.CheckValue(new_value) + self._fields[field] = new_value # Check _cached_byte_size_dirty inline to improve performance, since scalar # setters are called frequently. if not self._cached_byte_size_dirty: self._Modified() - if field.containing_oneof: - def setter(self, new_value): - field_setter(self, new_value) - self._UpdateOneofState(field) - else: - setter = field_setter - setter.__module__ = None setter.__doc__ = 'Setter for %s.' % proto_field_name @@ -707,11 +482,18 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): proto_field_name = field.name property_name = _PropertyName(proto_field_name) + # TODO(komarek): Can anyone explain to me why we cache the message_type this + # way, instead of referring to field.message_type inside of getter(self)? + # What if someone sets message_type later on (which makes for simpler + # dyanmic proto descriptor and class creation code). + message_type = field.message_type + def getter(self): field_value = self._fields.get(field) if field_value is None: # Construct a new object to represent this field. - field_value = field._default_constructor(self) + field_value = message_type._concrete_class() # use field.message_type? + field_value._SetListener(self._listener_for_children) # Atomically check if another thread has preempted us and, if not, swap # in the new object we just created. If someone has preempted us, we @@ -738,7 +520,7 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): def _AddPropertiesForExtensions(descriptor, cls): """Adds properties for all fields in this protocol message type.""" extension_dict = descriptor.extensions_by_name - for extension_name, extension_field in extension_dict.items(): + for extension_name, extension_field in extension_dict.iteritems(): constant_name = extension_name.upper() + "_FIELD_NUMBER" setattr(cls, constant_name, extension_field.number) @@ -793,54 +575,33 @@ def _AddListFieldsMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" def ListFields(self): - all_fields = [item for item in self._fields.items() if _IsPresent(item)] + all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] all_fields.sort(key = lambda item: item[0].number) return all_fields cls.ListFields = ListFields -_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"' -_Proto2HasError = 'Protocol message has no non-repeated field "%s"' def _AddHasFieldMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - is_proto3 = (message_descriptor.syntax == "proto3") - error_msg = _Proto3HasError if is_proto3 else _Proto2HasError - - hassable_fields = {} + singular_fields = {} for field in message_descriptor.fields: - if field.label == _FieldDescriptor.LABEL_REPEATED: - continue - # For proto3, only submessages and fields inside a oneof have presence. - if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and - not field.containing_oneof): - continue - hassable_fields[field.name] = field - - if not is_proto3: - # Fields inside oneofs are never repeated (enforced by the compiler). - for oneof in message_descriptor.oneofs: - hassable_fields[oneof.name] = oneof + if field.label != _FieldDescriptor.LABEL_REPEATED: + singular_fields[field.name] = field def HasField(self, field_name): try: - field = hassable_fields[field_name] + field = singular_fields[field_name] except KeyError: - raise ValueError(error_msg % field_name) + raise ValueError( + 'Protocol message has no singular "%s" field.' % field_name) - if isinstance(field, descriptor_mod.OneofDescriptor): - try: - return HasField(self, self._oneofs[field].name) - except KeyError: - return False + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + value = self._fields.get(field) + return value is not None and value._is_present_in_parent else: - if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - value = self._fields.get(field) - return value is not None and value._is_present_in_parent - else: - return field in self._fields - + return field in self._fields cls.HasField = HasField @@ -850,30 +611,14 @@ def _AddClearFieldMethod(message_descriptor, cls): try: field = message_descriptor.fields_by_name[field_name] except KeyError: - try: - field = message_descriptor.oneofs_by_name[field_name] - if field in self._oneofs: - field = self._oneofs[field] - else: - return - except KeyError: - raise ValueError('Protocol message %s() has no "%s" field.' % - (message_descriptor.name, field_name)) + raise ValueError('Protocol message has no "%s" field.' % field_name) if field in self._fields: - # To match the C++ implementation, we need to invalidate iterators - # for map fields when ClearField() happens. - if hasattr(self._fields[field], 'InvalidateIterators'): - self._fields[field].InvalidateIterators() - # Note: If the field is a sub-message, its listener will still point # at us. That's fine, because the worst than can happen is that it # will call _Modified() and invalidate our byte size. Big deal. del self._fields[field] - if self._oneofs.get(field.containing_oneof, None) is field: - del self._oneofs[field.containing_oneof] - # Always call _Modified() -- even if nothing was changed, this is # a mutating method, and thus calling it should cause the field to become # present in the parent message. @@ -894,6 +639,16 @@ def _AddClearExtensionMethod(cls): cls.ClearExtension = ClearExtension +def _AddClearMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def Clear(self): + # Clear fields. + self._fields = {} + self._unknown_fields = () + self._Modified() + cls.Clear = Clear + + def _AddHasExtensionMethod(cls): """Helper for _AddMessageMethods().""" def HasExtension(self, extension_handle): @@ -908,45 +663,6 @@ def _AddHasExtensionMethod(cls): return extension_handle in self._fields cls.HasExtension = HasExtension -def _InternalUnpackAny(msg): - """Unpacks Any message and returns the unpacked message. - - This internal method is differnt from public Any Unpack method which takes - the target message as argument. _InternalUnpackAny method does not have - target message type and need to find the message type in descriptor pool. - - Args: - msg: An Any message to be unpacked. - - Returns: - The unpacked message. - """ - # TODO(amauryfa): Don't use the factory of generated messages. - # To make Any work with custom factories, use the message factory of the - # parent message. - # pylint: disable=g-import-not-at-top - from google.protobuf import symbol_database - factory = symbol_database.Default() - - type_url = msg.type_url - - if not type_url: - return None - - # TODO(haberman): For now we just strip the hostname. Better logic will be - # required. - type_name = type_url.split('/')[-1] - descriptor = factory.pool.FindMessageTypeByName(type_name) - - if descriptor is None: - return None - - message_class = factory.GetPrototype(descriptor) - message = message_class() - - message.ParseFromString(msg.value) - return message - def _AddEqualsMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" @@ -958,12 +674,6 @@ def _AddEqualsMethod(message_descriptor, cls): if self is other: return True - if self.DESCRIPTOR.full_name == _AnyFullTypeName: - any_a = _InternalUnpackAny(self) - any_b = _InternalUnpackAny(other) - if any_a and any_b: - return any_a == any_b - if not self.ListFields() == other.ListFields(): return False @@ -985,13 +695,6 @@ def _AddStrMethod(message_descriptor, cls): cls.__str__ = __str__ -def _AddReprMethod(message_descriptor, cls): - """Helper for _AddMessageMethods().""" - def __repr__(self): - return text_format.MessageToString(self) - cls.__repr__ = __repr__ - - def _AddUnicodeMethod(unused_message_descriptor, cls): """Helper for _AddMessageMethods().""" @@ -1000,6 +703,16 @@ def _AddUnicodeMethod(unused_message_descriptor, cls): cls.__unicode__ = __unicode__ +def _AddSetListenerMethod(cls): + """Helper for _AddMessageMethods().""" + def SetListener(self, listener): + if listener is None: + self._listener = message_listener_mod.NullMessageListener() + else: + self._listener = listener + cls._SetListener = SetListener + + def _BytesForNonRepeatedElement(value, field_number, field_type): """Returns the number of bytes needed to serialize a non-repeated element. The returned byte count includes space for tag information and any @@ -1060,7 +773,7 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" def SerializePartialToString(self): - out = BytesIO() + out = StringIO() self._InternalSerialize(out.write) return out.getvalue() cls.SerializePartialToString = SerializePartialToString @@ -1083,10 +796,9 @@ def _AddMergeFromStringMethod(message_descriptor, cls): # The only reason _InternalParse would return early is if it # encountered an end-group tag. raise message_mod.DecodeError('Unexpected end-group tag.') - except (IndexError, TypeError): - # Now ord(buf[p:p+1]) == ord('') gets TypeError. + except IndexError: raise message_mod.DecodeError('Truncated message.') - except struct.error as e: + except struct.error, e: raise message_mod.DecodeError(e) return length # Return this for legacy reasons. cls.MergeFromString = MergeFromString @@ -1094,7 +806,6 @@ def _AddMergeFromStringMethod(message_descriptor, cls): local_ReadTag = decoder.ReadTag local_SkipField = decoder.SkipField decoders_by_tag = cls._decoders_by_tag - is_proto3 = message_descriptor.syntax == "proto3" def InternalParse(self, buffer, pos, end): self._Modified() @@ -1102,22 +813,18 @@ def _AddMergeFromStringMethod(message_descriptor, cls): unknown_field_list = self._unknown_fields while pos != end: (tag_bytes, new_pos) = local_ReadTag(buffer, pos) - field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) + field_decoder = decoders_by_tag.get(tag_bytes) if field_decoder is None: value_start_pos = new_pos new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) if new_pos == -1: return pos - if not is_proto3: - if not unknown_field_list: - unknown_field_list = self._unknown_fields = [] - unknown_field_list.append( - (tag_bytes, buffer[value_start_pos:new_pos])) + if not unknown_field_list: + unknown_field_list = self._unknown_fields = [] + unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) pos = new_pos else: pos = field_decoder(buffer, new_pos, end, self, field_dict) - if field_desc: - self._UpdateOneofState(field_desc) return pos cls._InternalParse = InternalParse @@ -1150,12 +857,9 @@ def _AddIsInitializedMethod(message_descriptor, cls): errors.extend(self.FindInitializationErrors()) return False - for field, value in list(self._fields.items()): # dict can change size! + for field, value in self._fields.iteritems(): if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: if field.label == _FieldDescriptor.LABEL_REPEATED: - if (field.message_type.has_options and - field.message_type.GetOptions().map_entry): - continue for element in value: if not element.IsInitialized(): if errors is not None: @@ -1191,26 +895,16 @@ def _AddIsInitializedMethod(message_descriptor, cls): else: name = field.name - if _IsMapField(field): - if _IsMessageMapField(field): - for key in value: - element = value[key] - prefix = "%s[%s]." % (name, key) - sub_errors = element.FindInitializationErrors() - errors += [prefix + error for error in sub_errors] - else: - # ScalarMaps can't have any initialization errors. - pass - elif field.label == _FieldDescriptor.LABEL_REPEATED: - for i in range(len(value)): + if field.label == _FieldDescriptor.LABEL_REPEATED: + for i in xrange(len(value)): element = value[i] prefix = "%s[%d]." % (name, i) sub_errors = element.FindInitializationErrors() - errors += [prefix + error for error in sub_errors] + errors += [ prefix + error for error in sub_errors ] else: prefix = name + "." sub_errors = value.FindInitializationErrors() - errors += [prefix + error for error in sub_errors] + errors += [ prefix + error for error in sub_errors ] return errors @@ -1232,7 +926,7 @@ def _AddMergeFromMethod(cls): fields = self._fields - for field, value in msg._fields.items(): + for field, value in msg._fields.iteritems(): if field.label == LABEL_REPEATED: field_value = fields.get(field) if field_value is None: @@ -1250,8 +944,6 @@ def _AddMergeFromMethod(cls): field_value.MergeFrom(value) else: self._fields[field] = value - if field.containing_oneof: - self._UpdateOneofState(field) if msg._unknown_fields: if not self._unknown_fields: @@ -1261,50 +953,6 @@ def _AddMergeFromMethod(cls): cls.MergeFrom = MergeFrom -def _AddWhichOneofMethod(message_descriptor, cls): - def WhichOneof(self, oneof_name): - """Returns the name of the currently set field inside a oneof, or None.""" - try: - field = message_descriptor.oneofs_by_name[oneof_name] - except KeyError: - raise ValueError( - 'Protocol message has no oneof "%s" field.' % oneof_name) - - nested_field = self._oneofs.get(field, None) - if nested_field is not None and self.HasField(nested_field.name): - return nested_field.name - else: - return None - - cls.WhichOneof = WhichOneof - - -def _Clear(self): - # Clear fields. - self._fields = {} - self._unknown_fields = () - self._oneofs = {} - self._Modified() - - -def _DiscardUnknownFields(self): - self._unknown_fields = [] - for field, value in self.ListFields(): - if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - if field.label == _FieldDescriptor.LABEL_REPEATED: - for sub_message in value: - sub_message.DiscardUnknownFields() - else: - value.DiscardUnknownFields() - - -def _SetListener(self, listener): - if listener is None: - self._listener = message_listener_mod.NullMessageListener() - else: - self._listener = listener - - def _AddMessageMethods(message_descriptor, cls): """Adds implementations of all Message methods to cls.""" _AddListFieldsMethod(message_descriptor, cls) @@ -1313,24 +961,20 @@ def _AddMessageMethods(message_descriptor, cls): if message_descriptor.is_extendable: _AddClearExtensionMethod(cls) _AddHasExtensionMethod(cls) + _AddClearMethod(message_descriptor, cls) _AddEqualsMethod(message_descriptor, cls) _AddStrMethod(message_descriptor, cls) - _AddReprMethod(message_descriptor, cls) _AddUnicodeMethod(message_descriptor, cls) + _AddSetListenerMethod(cls) _AddByteSizeMethod(message_descriptor, cls) _AddSerializeToStringMethod(message_descriptor, cls) _AddSerializePartialToStringMethod(message_descriptor, cls) _AddMergeFromStringMethod(message_descriptor, cls) _AddIsInitializedMethod(message_descriptor, cls) _AddMergeFromMethod(cls) - _AddWhichOneofMethod(message_descriptor, cls) - # Adds methods which do not depend on cls. - cls.Clear = _Clear - cls.DiscardUnknownFields = _DiscardUnknownFields - cls._SetListener = _SetListener -def _AddPrivateHelperMethods(message_descriptor, cls): +def _AddPrivateHelperMethods(cls): """Adds implementation of private helper methods to cls.""" def Modified(self): @@ -1348,20 +992,8 @@ def _AddPrivateHelperMethods(message_descriptor, cls): self._is_present_in_parent = True self._listener.Modified() - def _UpdateOneofState(self, field): - """Sets field as the active field in its containing oneof. - - Will also delete currently active field in the oneof, if it is different - from the argument. Does not mark the message as modified. - """ - other_field = self._oneofs.setdefault(field.containing_oneof, field) - if other_field is not field: - del self._fields[other_field] - self._oneofs[field.containing_oneof] = field - cls._Modified = Modified cls.SetInParent = Modified - cls._UpdateOneofState = _UpdateOneofState class _Listener(object): @@ -1410,27 +1042,6 @@ class _Listener(object): pass -class _OneofListener(_Listener): - """Special listener implementation for setting composite oneof fields.""" - - def __init__(self, parent_message, field): - """Args: - parent_message: The message whose _Modified() method we should call when - we receive Modified() messages. - field: The descriptor of the field being set in the parent message. - """ - super(_OneofListener, self).__init__(parent_message) - self._field = field - - def Modified(self): - """Also updates the state of the containing oneof in the parent message.""" - try: - self._parent_message_weakref._UpdateOneofState(self._field) - super(_OneofListener, self).Modified() - except ReferenceError: - pass - - # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... # TODO(robinson): Unify error handling of "unknown extension" crap. # TODO(robinson): Support iteritems()-style iteration over all @@ -1521,10 +1132,10 @@ class _ExtensionDict(object): # It's slightly wasteful to lookup the type checker each time, # but we expect this to be a vanishingly uncommon case anyway. - type_checker = type_checkers.GetTypeChecker(extension_handle) - # pylint: disable=protected-access - self._extended_message._fields[extension_handle] = ( - type_checker.CheckValue(value)) + type_checker = type_checkers.GetTypeChecker( + extension_handle.cpp_type, extension_handle.type) + type_checker.CheckValue(value) + self._extended_message._fields[extension_handle] = value self._extended_message._Modified() def _FindExtensionByName(self, name): @@ -1537,14 +1148,3 @@ class _ExtensionDict(object): Extension field descriptor. """ return self._extended_message._extensions_by_name.get(name, None) - - def _FindExtensionByNumber(self, number): - """Tries to find a known extension with the field number. - - Args: - number: Extension field number. - - Returns: - Extension field descriptor. - """ - return self._extended_message._extensions_by_number.get(number, None) |