aboutsummaryrefslogtreecommitdiff
path: root/generator/google/protobuf/internal/python_message.py
diff options
context:
space:
mode:
Diffstat (limited to 'generator/google/protobuf/internal/python_message.py')
-rw-r--r--generator/google/protobuf/internal/python_message.py628
1 files changed, 114 insertions, 514 deletions
diff --git a/generator/google/protobuf/internal/python_message.py b/generator/google/protobuf/internal/python_message.py
index c0d0ad4..4bea57a 100644
--- a/generator/google/protobuf/internal/python_message.py
+++ b/generator/google/protobuf/internal/python_message.py
@@ -1,6 +1,6 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
-# https://developers.google.com/protocol-buffers/
+# http://code.google.com/p/protobuf/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
@@ -50,20 +50,13 @@ this file*.
__author__ = 'robinson@google.com (Will Robinson)'
-from io import BytesIO
-import sys
-import struct
-import weakref
-
-import six
try:
- import six.moves.copyreg as copyreg
+ from cStringIO import StringIO
except ImportError:
- # On some platforms, for example gMac, we run native Python because there is
- # nothing like hermetic Python. This means lesser control on the system and
- # the six.moves package may be missing (is missing on 20150321 on gMac). Be
- # extra conservative and try to load the old replacement if it fails.
- import copy_reg as copyreg
+ from StringIO import StringIO
+import copy_reg
+import struct
+import weakref
# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import containers
@@ -72,116 +65,41 @@ from google.protobuf.internal import encoder
from google.protobuf.internal import enum_type_wrapper
from google.protobuf.internal import message_listener as message_listener_mod
from google.protobuf.internal import type_checkers
-from google.protobuf.internal import well_known_types
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
from google.protobuf import text_format
_FieldDescriptor = descriptor_mod.FieldDescriptor
-_AnyFullTypeName = 'google.protobuf.Any'
-class GeneratedProtocolMessageType(type):
+def NewMessage(bases, descriptor, dictionary):
+ _AddClassAttributesForNestedExtensions(descriptor, dictionary)
+ _AddSlots(descriptor, dictionary)
+ return bases
- """Metaclass for protocol message classes created at runtime from Descriptors.
-
- We add implementations for all methods described in the Message class. We
- also create properties to allow getting/setting all fields in the protocol
- message. Finally, we create slots to prevent users from accidentally
- "setting" nonexistent fields in the protocol message, which then wouldn't get
- serialized / deserialized properly.
-
- The protocol compiler currently uses this metaclass to create protocol
- message classes at runtime. Clients can also manually create their own
- classes at runtime, as in this example:
-
- mydescriptor = Descriptor(.....)
- factory = symbol_database.Default()
- factory.pool.AddDescriptor(mydescriptor)
- MyProtoClass = factory.GetPrototype(mydescriptor)
- myproto_instance = MyProtoClass()
- myproto.foo_field = 23
- ...
- """
- # Must be consistent with the protocol-compiler code in
- # proto2/compiler/internal/generator.*.
- _DESCRIPTOR_KEY = 'DESCRIPTOR'
+def InitMessage(descriptor, cls):
+ cls._decoders_by_tag = {}
+ cls._extensions_by_name = {}
+ cls._extensions_by_number = {}
+ if (descriptor.has_options and
+ descriptor.GetOptions().message_set_wire_format):
+ cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
+ decoder.MessageSetItemDecoder(cls._extensions_by_number))
- def __new__(cls, name, bases, dictionary):
- """Custom allocation for runtime-generated class types.
-
- We override __new__ because this is apparently the only place
- where we can meaningfully set __slots__ on the class we're creating(?).
- (The interplay between metaclasses and slots is not very well-documented).
-
- Args:
- name: Name of the class (ignored, but required by the
- metaclass protocol).
- bases: Base classes of the class we're constructing.
- (Should be message.Message). We ignore this field, but
- it's required by the metaclass protocol
- dictionary: The class dictionary of the class we're
- constructing. dictionary[_DESCRIPTOR_KEY] must contain
- a Descriptor object describing this protocol message
- type.
-
- Returns:
- Newly-allocated class.
- """
- descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
- if descriptor.full_name in well_known_types.WKTBASES:
- bases += (well_known_types.WKTBASES[descriptor.full_name],)
- _AddClassAttributesForNestedExtensions(descriptor, dictionary)
- _AddSlots(descriptor, dictionary)
-
- superclass = super(GeneratedProtocolMessageType, cls)
- new_class = superclass.__new__(cls, name, bases, dictionary)
- return new_class
-
- def __init__(cls, name, bases, dictionary):
- """Here we perform the majority of our work on the class.
- We add enum getters, an __init__ method, implementations
- of all Message methods, and properties for all fields
- in the protocol type.
+ # Attach stuff to each FieldDescriptor for quick lookup later on.
+ for field in descriptor.fields:
+ _AttachFieldHelpers(cls, field)
- Args:
- name: Name of the class (ignored, but required by the
- metaclass protocol).
- bases: Base classes of the class we're constructing.
- (Should be message.Message). We ignore this field, but
- it's required by the metaclass protocol
- dictionary: The class dictionary of the class we're
- constructing. dictionary[_DESCRIPTOR_KEY] must contain
- a Descriptor object describing this protocol message
- type.
- """
- descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
- cls._decoders_by_tag = {}
- cls._extensions_by_name = {}
- cls._extensions_by_number = {}
- if (descriptor.has_options and
- descriptor.GetOptions().message_set_wire_format):
- cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
- decoder.MessageSetItemDecoder(cls._extensions_by_number), None)
-
- # Attach stuff to each FieldDescriptor for quick lookup later on.
- for field in descriptor.fields:
- _AttachFieldHelpers(cls, field)
-
- descriptor._concrete_class = cls # pylint: disable=protected-access
- _AddEnumValues(descriptor, cls)
- _AddInitMethod(descriptor, cls)
- _AddPropertiesForFields(descriptor, cls)
- _AddPropertiesForExtensions(descriptor, cls)
- _AddStaticMethods(cls)
- _AddMessageMethods(descriptor, cls)
- _AddPrivateHelperMethods(descriptor, cls)
- copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
-
- superclass = super(GeneratedProtocolMessageType, cls)
- superclass.__init__(name, bases, dictionary)
+ _AddEnumValues(descriptor, cls)
+ _AddInitMethod(descriptor, cls)
+ _AddPropertiesForFields(descriptor, cls)
+ _AddPropertiesForExtensions(descriptor, cls)
+ _AddStaticMethods(cls)
+ _AddMessageMethods(descriptor, cls)
+ _AddPrivateHelperMethods(cls)
+ copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
# Stateless helpers for GeneratedProtocolMessageType below.
@@ -258,8 +176,7 @@ def _AddSlots(message_descriptor, dictionary):
'_is_present_in_parent',
'_listener',
'_listener_for_children',
- '__weakref__',
- '_oneofs']
+ '__weakref__']
def _IsMessageSetExtension(field):
@@ -267,40 +184,16 @@ def _IsMessageSetExtension(field):
field.containing_type.has_options and
field.containing_type.GetOptions().message_set_wire_format and
field.type == _FieldDescriptor.TYPE_MESSAGE and
+ field.message_type == field.extension_scope and
field.label == _FieldDescriptor.LABEL_OPTIONAL)
-def _IsMapField(field):
- return (field.type == _FieldDescriptor.TYPE_MESSAGE and
- field.message_type.has_options and
- field.message_type.GetOptions().map_entry)
-
-
-def _IsMessageMapField(field):
- value_type = field.message_type.fields_by_name["value"]
- return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
-
-
def _AttachFieldHelpers(cls, field_descriptor):
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
- is_packable = (is_repeated and
- wire_format.IsTypePackable(field_descriptor.type))
- if not is_packable:
- is_packed = False
- elif field_descriptor.containing_type.syntax == "proto2":
- is_packed = (field_descriptor.has_options and
- field_descriptor.GetOptions().packed)
- else:
- has_packed_false = (field_descriptor.has_options and
- field_descriptor.GetOptions().HasField("packed") and
- field_descriptor.GetOptions().packed == False)
- is_packed = not has_packed_false
- is_map_entry = _IsMapField(field_descriptor)
-
- if is_map_entry:
- field_encoder = encoder.MapEncoder(field_descriptor)
- sizer = encoder.MapSizer(field_descriptor)
- elif _IsMessageSetExtension(field_descriptor):
+ is_packed = (field_descriptor.has_options and
+ field_descriptor.GetOptions().packed)
+
+ if _IsMessageSetExtension(field_descriptor):
field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
sizer = encoder.MessageSetItemSizer(field_descriptor.number)
else:
@@ -316,27 +209,10 @@ def _AttachFieldHelpers(cls, field_descriptor):
def AddDecoder(wiretype, is_packed):
tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
- decode_type = field_descriptor.type
- if (decode_type == _FieldDescriptor.TYPE_ENUM and
- type_checkers.SupportsOpenEnums(field_descriptor)):
- decode_type = _FieldDescriptor.TYPE_INT32
-
- oneof_descriptor = None
- if field_descriptor.containing_oneof is not None:
- oneof_descriptor = field_descriptor
-
- if is_map_entry:
- is_message_map = _IsMessageMapField(field_descriptor)
-
- field_decoder = decoder.MapDecoder(
- field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
- is_message_map)
- else:
- field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
- field_descriptor.number, is_repeated, is_packed,
- field_descriptor, field_descriptor._default_constructor)
-
- cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
+ cls._decoders_by_tag[tag_bytes] = (
+ type_checkers.TYPE_TO_DECODER[field_descriptor.type](
+ field_descriptor.number, is_repeated, is_packed,
+ field_descriptor, field_descriptor._default_constructor))
AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
False)
@@ -349,7 +225,7 @@ def _AttachFieldHelpers(cls, field_descriptor):
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
extension_dict = descriptor.extensions_by_name
- for extension_name, extension_field in extension_dict.items():
+ for extension_name, extension_field in extension_dict.iteritems():
assert extension_name not in dictionary
dictionary[extension_name] = extension_field
@@ -369,26 +245,6 @@ def _AddEnumValues(descriptor, cls):
setattr(cls, enum_value.name, enum_value.number)
-def _GetInitializeDefaultForMap(field):
- if field.label != _FieldDescriptor.LABEL_REPEATED:
- raise ValueError('map_entry set on non-repeated field %s' % (
- field.name))
- fields_by_name = field.message_type.fields_by_name
- key_checker = type_checkers.GetTypeChecker(fields_by_name['key'])
-
- value_field = fields_by_name['value']
- if _IsMessageMapField(field):
- def MakeMessageMapDefault(message):
- return containers.MessageMap(
- message._listener_for_children, value_field.message_type, key_checker)
- return MakeMessageMapDefault
- else:
- value_checker = type_checkers.GetTypeChecker(value_field)
- def MakePrimitiveMapDefault(message):
- return containers.ScalarMap(
- message._listener_for_children, key_checker, value_checker)
- return MakePrimitiveMapDefault
-
def _DefaultValueConstructorForField(field):
"""Returns a function which returns a default value for a field.
@@ -403,9 +259,6 @@ def _DefaultValueConstructorForField(field):
value may refer back to |message| via a weak reference.
"""
- if _IsMapField(field):
- return _GetInitializeDefaultForMap(field)
-
if field.label == _FieldDescriptor.LABEL_REPEATED:
if field.has_default_value and field.default_value != []:
raise ValueError('Repeated field default value not empty list: %s' % (
@@ -419,7 +272,7 @@ def _DefaultValueConstructorForField(field):
message._listener_for_children, field.message_type)
return MakeRepeatedMessageDefault
else:
- type_checker = type_checkers.GetTypeChecker(field)
+ type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
def MakeRepeatedScalarDefault(message):
return containers.RepeatedScalarFieldContainer(
message._listener_for_children, type_checker)
@@ -430,10 +283,7 @@ def _DefaultValueConstructorForField(field):
message_type = field.message_type
def MakeSubMessageDefault(message):
result = message_type._concrete_class()
- result._SetListener(
- _OneofListener(message, field)
- if field.containing_oneof is not None
- else message._listener_for_children)
+ result._SetListener(message._listener_for_children)
return result
return MakeSubMessageDefault
@@ -444,95 +294,38 @@ def _DefaultValueConstructorForField(field):
return MakeScalarDefault
-def _ReraiseTypeErrorWithFieldName(message_name, field_name):
- """Re-raise the currently-handled TypeError with the field name added."""
- exc = sys.exc_info()[1]
- if len(exc.args) == 1 and type(exc) is TypeError:
- # simple TypeError; add field name to exception message
- exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))
-
- # re-raise possibly-amended exception with original traceback:
- six.reraise(type(exc), exc, sys.exc_info()[2])
-
-
def _AddInitMethod(message_descriptor, cls):
"""Adds an __init__ method to cls."""
-
- def _GetIntegerEnumValue(enum_type, value):
- """Convert a string or integer enum value to an integer.
-
- If the value is a string, it is converted to the enum value in
- enum_type with the same name. If the value is not a string, it's
- returned as-is. (No conversion or bounds-checking is done.)
- """
- if isinstance(value, six.string_types):
- try:
- return enum_type.values_by_name[value].number
- except KeyError:
- raise ValueError('Enum type %s: unknown label "%s"' % (
- enum_type.full_name, value))
- return value
-
+ fields = message_descriptor.fields
def init(self, **kwargs):
self._cached_byte_size = 0
self._cached_byte_size_dirty = len(kwargs) > 0
self._fields = {}
- # Contains a mapping from oneof field descriptors to the descriptor
- # of the currently set field in that oneof field.
- self._oneofs = {}
-
# _unknown_fields is () when empty for efficiency, and will be turned into
# a list if fields are added.
self._unknown_fields = ()
self._is_present_in_parent = False
self._listener = message_listener_mod.NullMessageListener()
self._listener_for_children = _Listener(self)
- for field_name, field_value in kwargs.items():
+ for field_name, field_value in kwargs.iteritems():
field = _GetFieldByName(message_descriptor, field_name)
if field is None:
raise TypeError("%s() got an unexpected keyword argument '%s'" %
(message_descriptor.name, field_name))
- if field_value is None:
- # field=None is the same as no field at all.
- continue
if field.label == _FieldDescriptor.LABEL_REPEATED:
copy = field._default_constructor(self)
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
- if _IsMapField(field):
- if _IsMessageMapField(field):
- for key in field_value:
- copy[key].MergeFrom(field_value[key])
- else:
- copy.update(field_value)
- else:
- for val in field_value:
- if isinstance(val, dict):
- copy.add(**val)
- else:
- copy.add().MergeFrom(val)
+ for val in field_value:
+ copy.add().MergeFrom(val)
else: # Scalar
- if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
- field_value = [_GetIntegerEnumValue(field.enum_type, val)
- for val in field_value]
copy.extend(field_value)
self._fields[field] = copy
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
copy = field._default_constructor(self)
- new_val = field_value
- if isinstance(field_value, dict):
- new_val = field.message_type._concrete_class(**field_value)
- try:
- copy.MergeFrom(new_val)
- except TypeError:
- _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
+ copy.MergeFrom(field_value)
self._fields[field] = copy
else:
- if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
- field_value = _GetIntegerEnumValue(field.enum_type, field_value)
- try:
- setattr(self, field_name, field_value)
- except TypeError:
- _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
+ setattr(self, field_name, field_value)
init.__module__ = None
init.__doc__ = None
@@ -551,8 +344,7 @@ def _GetFieldByName(message_descriptor, field_name):
try:
return message_descriptor.fields_by_name[field_name]
except KeyError:
- raise ValueError('Protocol message %s has no "%s" field.' %
- (message_descriptor.name, field_name))
+ raise ValueError('Protocol message has no "%s" field.' % field_name)
def _AddPropertiesForFields(descriptor, cls):
@@ -648,10 +440,9 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
"""
proto_field_name = field.name
property_name = _PropertyName(proto_field_name)
- type_checker = type_checkers.GetTypeChecker(field)
+ type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
default_value = field.default_value
valid_values = set()
- is_proto3 = field.containing_type.syntax == "proto3"
def getter(self):
# TODO(protobuf-team): This may be broken since there may not be
@@ -659,30 +450,14 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
return self._fields.get(field, default_value)
getter.__module__ = None
getter.__doc__ = 'Getter for %s.' % proto_field_name
-
- clear_when_set_to_default = is_proto3 and not field.containing_oneof
-
- def field_setter(self, new_value):
- # pylint: disable=protected-access
- # Testing the value for truthiness captures all of the proto3 defaults
- # (0, 0.0, enum 0, and False).
- new_value = type_checker.CheckValue(new_value)
- if clear_when_set_to_default and not new_value:
- self._fields.pop(field, None)
- else:
- self._fields[field] = new_value
+ def setter(self, new_value):
+ type_checker.CheckValue(new_value)
+ self._fields[field] = new_value
# Check _cached_byte_size_dirty inline to improve performance, since scalar
# setters are called frequently.
if not self._cached_byte_size_dirty:
self._Modified()
- if field.containing_oneof:
- def setter(self, new_value):
- field_setter(self, new_value)
- self._UpdateOneofState(field)
- else:
- setter = field_setter
-
setter.__module__ = None
setter.__doc__ = 'Setter for %s.' % proto_field_name
@@ -707,11 +482,18 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls):
proto_field_name = field.name
property_name = _PropertyName(proto_field_name)
+ # TODO(komarek): Can anyone explain to me why we cache the message_type this
+ # way, instead of referring to field.message_type inside of getter(self)?
+ # What if someone sets message_type later on (which makes for simpler
+ # dyanmic proto descriptor and class creation code).
+ message_type = field.message_type
+
def getter(self):
field_value = self._fields.get(field)
if field_value is None:
# Construct a new object to represent this field.
- field_value = field._default_constructor(self)
+ field_value = message_type._concrete_class() # use field.message_type?
+ field_value._SetListener(self._listener_for_children)
# Atomically check if another thread has preempted us and, if not, swap
# in the new object we just created. If someone has preempted us, we
@@ -738,7 +520,7 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls):
def _AddPropertiesForExtensions(descriptor, cls):
"""Adds properties for all fields in this protocol message type."""
extension_dict = descriptor.extensions_by_name
- for extension_name, extension_field in extension_dict.items():
+ for extension_name, extension_field in extension_dict.iteritems():
constant_name = extension_name.upper() + "_FIELD_NUMBER"
setattr(cls, constant_name, extension_field.number)
@@ -793,54 +575,33 @@ def _AddListFieldsMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
def ListFields(self):
- all_fields = [item for item in self._fields.items() if _IsPresent(item)]
+ all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
all_fields.sort(key = lambda item: item[0].number)
return all_fields
cls.ListFields = ListFields
-_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"'
-_Proto2HasError = 'Protocol message has no non-repeated field "%s"'
def _AddHasFieldMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
- is_proto3 = (message_descriptor.syntax == "proto3")
- error_msg = _Proto3HasError if is_proto3 else _Proto2HasError
-
- hassable_fields = {}
+ singular_fields = {}
for field in message_descriptor.fields:
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- continue
- # For proto3, only submessages and fields inside a oneof have presence.
- if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
- not field.containing_oneof):
- continue
- hassable_fields[field.name] = field
-
- if not is_proto3:
- # Fields inside oneofs are never repeated (enforced by the compiler).
- for oneof in message_descriptor.oneofs:
- hassable_fields[oneof.name] = oneof
+ if field.label != _FieldDescriptor.LABEL_REPEATED:
+ singular_fields[field.name] = field
def HasField(self, field_name):
try:
- field = hassable_fields[field_name]
+ field = singular_fields[field_name]
except KeyError:
- raise ValueError(error_msg % field_name)
+ raise ValueError(
+ 'Protocol message has no singular "%s" field.' % field_name)
- if isinstance(field, descriptor_mod.OneofDescriptor):
- try:
- return HasField(self, self._oneofs[field].name)
- except KeyError:
- return False
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ value = self._fields.get(field)
+ return value is not None and value._is_present_in_parent
else:
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- value = self._fields.get(field)
- return value is not None and value._is_present_in_parent
- else:
- return field in self._fields
-
+ return field in self._fields
cls.HasField = HasField
@@ -850,30 +611,14 @@ def _AddClearFieldMethod(message_descriptor, cls):
try:
field = message_descriptor.fields_by_name[field_name]
except KeyError:
- try:
- field = message_descriptor.oneofs_by_name[field_name]
- if field in self._oneofs:
- field = self._oneofs[field]
- else:
- return
- except KeyError:
- raise ValueError('Protocol message %s() has no "%s" field.' %
- (message_descriptor.name, field_name))
+ raise ValueError('Protocol message has no "%s" field.' % field_name)
if field in self._fields:
- # To match the C++ implementation, we need to invalidate iterators
- # for map fields when ClearField() happens.
- if hasattr(self._fields[field], 'InvalidateIterators'):
- self._fields[field].InvalidateIterators()
-
# Note: If the field is a sub-message, its listener will still point
# at us. That's fine, because the worst than can happen is that it
# will call _Modified() and invalidate our byte size. Big deal.
del self._fields[field]
- if self._oneofs.get(field.containing_oneof, None) is field:
- del self._oneofs[field.containing_oneof]
-
# Always call _Modified() -- even if nothing was changed, this is
# a mutating method, and thus calling it should cause the field to become
# present in the parent message.
@@ -894,6 +639,16 @@ def _AddClearExtensionMethod(cls):
cls.ClearExtension = ClearExtension
+def _AddClearMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+ def Clear(self):
+ # Clear fields.
+ self._fields = {}
+ self._unknown_fields = ()
+ self._Modified()
+ cls.Clear = Clear
+
+
def _AddHasExtensionMethod(cls):
"""Helper for _AddMessageMethods()."""
def HasExtension(self, extension_handle):
@@ -908,45 +663,6 @@ def _AddHasExtensionMethod(cls):
return extension_handle in self._fields
cls.HasExtension = HasExtension
-def _InternalUnpackAny(msg):
- """Unpacks Any message and returns the unpacked message.
-
- This internal method is differnt from public Any Unpack method which takes
- the target message as argument. _InternalUnpackAny method does not have
- target message type and need to find the message type in descriptor pool.
-
- Args:
- msg: An Any message to be unpacked.
-
- Returns:
- The unpacked message.
- """
- # TODO(amauryfa): Don't use the factory of generated messages.
- # To make Any work with custom factories, use the message factory of the
- # parent message.
- # pylint: disable=g-import-not-at-top
- from google.protobuf import symbol_database
- factory = symbol_database.Default()
-
- type_url = msg.type_url
-
- if not type_url:
- return None
-
- # TODO(haberman): For now we just strip the hostname. Better logic will be
- # required.
- type_name = type_url.split('/')[-1]
- descriptor = factory.pool.FindMessageTypeByName(type_name)
-
- if descriptor is None:
- return None
-
- message_class = factory.GetPrototype(descriptor)
- message = message_class()
-
- message.ParseFromString(msg.value)
- return message
-
def _AddEqualsMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
@@ -958,12 +674,6 @@ def _AddEqualsMethod(message_descriptor, cls):
if self is other:
return True
- if self.DESCRIPTOR.full_name == _AnyFullTypeName:
- any_a = _InternalUnpackAny(self)
- any_b = _InternalUnpackAny(other)
- if any_a and any_b:
- return any_a == any_b
-
if not self.ListFields() == other.ListFields():
return False
@@ -985,13 +695,6 @@ def _AddStrMethod(message_descriptor, cls):
cls.__str__ = __str__
-def _AddReprMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
- def __repr__(self):
- return text_format.MessageToString(self)
- cls.__repr__ = __repr__
-
-
def _AddUnicodeMethod(unused_message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
@@ -1000,6 +703,16 @@ def _AddUnicodeMethod(unused_message_descriptor, cls):
cls.__unicode__ = __unicode__
+def _AddSetListenerMethod(cls):
+ """Helper for _AddMessageMethods()."""
+ def SetListener(self, listener):
+ if listener is None:
+ self._listener = message_listener_mod.NullMessageListener()
+ else:
+ self._listener = listener
+ cls._SetListener = SetListener
+
+
def _BytesForNonRepeatedElement(value, field_number, field_type):
"""Returns the number of bytes needed to serialize a non-repeated element.
The returned byte count includes space for tag information and any
@@ -1060,7 +773,7 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
def SerializePartialToString(self):
- out = BytesIO()
+ out = StringIO()
self._InternalSerialize(out.write)
return out.getvalue()
cls.SerializePartialToString = SerializePartialToString
@@ -1083,10 +796,9 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
# The only reason _InternalParse would return early is if it
# encountered an end-group tag.
raise message_mod.DecodeError('Unexpected end-group tag.')
- except (IndexError, TypeError):
- # Now ord(buf[p:p+1]) == ord('') gets TypeError.
+ except IndexError:
raise message_mod.DecodeError('Truncated message.')
- except struct.error as e:
+ except struct.error, e:
raise message_mod.DecodeError(e)
return length # Return this for legacy reasons.
cls.MergeFromString = MergeFromString
@@ -1094,7 +806,6 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
local_ReadTag = decoder.ReadTag
local_SkipField = decoder.SkipField
decoders_by_tag = cls._decoders_by_tag
- is_proto3 = message_descriptor.syntax == "proto3"
def InternalParse(self, buffer, pos, end):
self._Modified()
@@ -1102,22 +813,18 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
unknown_field_list = self._unknown_fields
while pos != end:
(tag_bytes, new_pos) = local_ReadTag(buffer, pos)
- field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
+ field_decoder = decoders_by_tag.get(tag_bytes)
if field_decoder is None:
value_start_pos = new_pos
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
if new_pos == -1:
return pos
- if not is_proto3:
- if not unknown_field_list:
- unknown_field_list = self._unknown_fields = []
- unknown_field_list.append(
- (tag_bytes, buffer[value_start_pos:new_pos]))
+ if not unknown_field_list:
+ unknown_field_list = self._unknown_fields = []
+ unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
pos = new_pos
else:
pos = field_decoder(buffer, new_pos, end, self, field_dict)
- if field_desc:
- self._UpdateOneofState(field_desc)
return pos
cls._InternalParse = InternalParse
@@ -1150,12 +857,9 @@ def _AddIsInitializedMethod(message_descriptor, cls):
errors.extend(self.FindInitializationErrors())
return False
- for field, value in list(self._fields.items()): # dict can change size!
+ for field, value in self._fields.iteritems():
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
if field.label == _FieldDescriptor.LABEL_REPEATED:
- if (field.message_type.has_options and
- field.message_type.GetOptions().map_entry):
- continue
for element in value:
if not element.IsInitialized():
if errors is not None:
@@ -1191,26 +895,16 @@ def _AddIsInitializedMethod(message_descriptor, cls):
else:
name = field.name
- if _IsMapField(field):
- if _IsMessageMapField(field):
- for key in value:
- element = value[key]
- prefix = "%s[%s]." % (name, key)
- sub_errors = element.FindInitializationErrors()
- errors += [prefix + error for error in sub_errors]
- else:
- # ScalarMaps can't have any initialization errors.
- pass
- elif field.label == _FieldDescriptor.LABEL_REPEATED:
- for i in range(len(value)):
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ for i in xrange(len(value)):
element = value[i]
prefix = "%s[%d]." % (name, i)
sub_errors = element.FindInitializationErrors()
- errors += [prefix + error for error in sub_errors]
+ errors += [ prefix + error for error in sub_errors ]
else:
prefix = name + "."
sub_errors = value.FindInitializationErrors()
- errors += [prefix + error for error in sub_errors]
+ errors += [ prefix + error for error in sub_errors ]
return errors
@@ -1232,7 +926,7 @@ def _AddMergeFromMethod(cls):
fields = self._fields
- for field, value in msg._fields.items():
+ for field, value in msg._fields.iteritems():
if field.label == LABEL_REPEATED:
field_value = fields.get(field)
if field_value is None:
@@ -1250,8 +944,6 @@ def _AddMergeFromMethod(cls):
field_value.MergeFrom(value)
else:
self._fields[field] = value
- if field.containing_oneof:
- self._UpdateOneofState(field)
if msg._unknown_fields:
if not self._unknown_fields:
@@ -1261,50 +953,6 @@ def _AddMergeFromMethod(cls):
cls.MergeFrom = MergeFrom
-def _AddWhichOneofMethod(message_descriptor, cls):
- def WhichOneof(self, oneof_name):
- """Returns the name of the currently set field inside a oneof, or None."""
- try:
- field = message_descriptor.oneofs_by_name[oneof_name]
- except KeyError:
- raise ValueError(
- 'Protocol message has no oneof "%s" field.' % oneof_name)
-
- nested_field = self._oneofs.get(field, None)
- if nested_field is not None and self.HasField(nested_field.name):
- return nested_field.name
- else:
- return None
-
- cls.WhichOneof = WhichOneof
-
-
-def _Clear(self):
- # Clear fields.
- self._fields = {}
- self._unknown_fields = ()
- self._oneofs = {}
- self._Modified()
-
-
-def _DiscardUnknownFields(self):
- self._unknown_fields = []
- for field, value in self.ListFields():
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- for sub_message in value:
- sub_message.DiscardUnknownFields()
- else:
- value.DiscardUnknownFields()
-
-
-def _SetListener(self, listener):
- if listener is None:
- self._listener = message_listener_mod.NullMessageListener()
- else:
- self._listener = listener
-
-
def _AddMessageMethods(message_descriptor, cls):
"""Adds implementations of all Message methods to cls."""
_AddListFieldsMethod(message_descriptor, cls)
@@ -1313,24 +961,20 @@ def _AddMessageMethods(message_descriptor, cls):
if message_descriptor.is_extendable:
_AddClearExtensionMethod(cls)
_AddHasExtensionMethod(cls)
+ _AddClearMethod(message_descriptor, cls)
_AddEqualsMethod(message_descriptor, cls)
_AddStrMethod(message_descriptor, cls)
- _AddReprMethod(message_descriptor, cls)
_AddUnicodeMethod(message_descriptor, cls)
+ _AddSetListenerMethod(cls)
_AddByteSizeMethod(message_descriptor, cls)
_AddSerializeToStringMethod(message_descriptor, cls)
_AddSerializePartialToStringMethod(message_descriptor, cls)
_AddMergeFromStringMethod(message_descriptor, cls)
_AddIsInitializedMethod(message_descriptor, cls)
_AddMergeFromMethod(cls)
- _AddWhichOneofMethod(message_descriptor, cls)
- # Adds methods which do not depend on cls.
- cls.Clear = _Clear
- cls.DiscardUnknownFields = _DiscardUnknownFields
- cls._SetListener = _SetListener
-def _AddPrivateHelperMethods(message_descriptor, cls):
+def _AddPrivateHelperMethods(cls):
"""Adds implementation of private helper methods to cls."""
def Modified(self):
@@ -1348,20 +992,8 @@ def _AddPrivateHelperMethods(message_descriptor, cls):
self._is_present_in_parent = True
self._listener.Modified()
- def _UpdateOneofState(self, field):
- """Sets field as the active field in its containing oneof.
-
- Will also delete currently active field in the oneof, if it is different
- from the argument. Does not mark the message as modified.
- """
- other_field = self._oneofs.setdefault(field.containing_oneof, field)
- if other_field is not field:
- del self._fields[other_field]
- self._oneofs[field.containing_oneof] = field
-
cls._Modified = Modified
cls.SetInParent = Modified
- cls._UpdateOneofState = _UpdateOneofState
class _Listener(object):
@@ -1410,27 +1042,6 @@ class _Listener(object):
pass
-class _OneofListener(_Listener):
- """Special listener implementation for setting composite oneof fields."""
-
- def __init__(self, parent_message, field):
- """Args:
- parent_message: The message whose _Modified() method we should call when
- we receive Modified() messages.
- field: The descriptor of the field being set in the parent message.
- """
- super(_OneofListener, self).__init__(parent_message)
- self._field = field
-
- def Modified(self):
- """Also updates the state of the containing oneof in the parent message."""
- try:
- self._parent_message_weakref._UpdateOneofState(self._field)
- super(_OneofListener, self).Modified()
- except ReferenceError:
- pass
-
-
# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
# TODO(robinson): Unify error handling of "unknown extension" crap.
# TODO(robinson): Support iteritems()-style iteration over all
@@ -1521,10 +1132,10 @@ class _ExtensionDict(object):
# It's slightly wasteful to lookup the type checker each time,
# but we expect this to be a vanishingly uncommon case anyway.
- type_checker = type_checkers.GetTypeChecker(extension_handle)
- # pylint: disable=protected-access
- self._extended_message._fields[extension_handle] = (
- type_checker.CheckValue(value))
+ type_checker = type_checkers.GetTypeChecker(
+ extension_handle.cpp_type, extension_handle.type)
+ type_checker.CheckValue(value)
+ self._extended_message._fields[extension_handle] = value
self._extended_message._Modified()
def _FindExtensionByName(self, name):
@@ -1537,14 +1148,3 @@ class _ExtensionDict(object):
Extension field descriptor.
"""
return self._extended_message._extensions_by_name.get(name, None)
-
- def _FindExtensionByNumber(self, number):
- """Tries to find a known extension with the field number.
-
- Args:
- number: Extension field number.
-
- Returns:
- Extension field descriptor.
- """
- return self._extended_message._extensions_by_number.get(number, None)