diff options
Diffstat (limited to 'generator/google/protobuf/internal/decoder.py')
-rw-r--r-- | generator/google/protobuf/internal/decoder.py | 194 |
1 files changed, 164 insertions, 30 deletions
diff --git a/generator/google/protobuf/internal/decoder.py b/generator/google/protobuf/internal/decoder.py index cb6f572..31869e4 100644 --- a/generator/google/protobuf/internal/decoder.py +++ b/generator/google/protobuf/internal/decoder.py @@ -1,6 +1,6 @@ # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -81,6 +81,12 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it. __author__ = 'kenton@google.com (Kenton Varda)' import struct + +import six + +if six.PY3: + long = int + from google.protobuf.internal import encoder from google.protobuf.internal import wire_format from google.protobuf import message @@ -98,7 +104,7 @@ _NAN = _POS_INF * 0 _DecodeError = message.DecodeError -def _VarintDecoder(mask): +def _VarintDecoder(mask, result_type): """Return an encoder for a basic varint value (does not include tag). Decoded values will be bitwise-anded with the given mask before being @@ -108,16 +114,16 @@ def _VarintDecoder(mask): decoder returns a (value, new_pos) pair. """ - local_ord = ord def DecodeVarint(buffer, pos): result = 0 shift = 0 while 1: - b = local_ord(buffer[pos]) + b = six.indexbytes(buffer, pos) result |= ((b & 0x7f) << shift) pos += 1 if not (b & 0x80): result &= mask + result = result_type(result) return (result, pos) shift += 7 if shift >= 64: @@ -125,15 +131,14 @@ def _VarintDecoder(mask): return DecodeVarint -def _SignedVarintDecoder(mask): +def _SignedVarintDecoder(mask, result_type): """Like _VarintDecoder() but decodes signed values.""" - local_ord = ord def DecodeVarint(buffer, pos): result = 0 shift = 0 while 1: - b = local_ord(buffer[pos]) + b = six.indexbytes(buffer, pos) result |= ((b & 0x7f) << shift) pos += 1 if not (b & 0x80): @@ -142,19 +147,23 @@ def _SignedVarintDecoder(mask): result |= ~mask else: result &= mask + result = result_type(result) return (result, pos) shift += 7 if shift >= 64: raise _DecodeError('Too many bytes when decoding varint.') return DecodeVarint +# We force 32-bit values to int and 64-bit values to long to make +# alternate implementations where the distinction is more significant +# (e.g. the C++ implementation) simpler. -_DecodeVarint = _VarintDecoder((1 << 64) - 1) -_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1) +_DecodeVarint = _VarintDecoder((1 << 64) - 1, long) +_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long) # Use these versions for values which must be limited to 32 bits. -_DecodeVarint32 = _VarintDecoder((1 << 32) - 1) -_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1) +_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int) +_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int) def ReadTag(buffer, pos): @@ -169,7 +178,7 @@ def ReadTag(buffer, pos): """ start = pos - while ord(buffer[pos]) & 0x80: + while six.indexbytes(buffer, pos) & 0x80: pos += 1 pos += 1 return (buffer[start:pos], pos) @@ -294,13 +303,12 @@ def _FloatDecoder(): # If this value has all its exponent bits set, then it's non-finite. # In Python 2.4, struct.unpack will convert it to a finite 64-bit value. # To avoid that, we parse it specially. - if ((float_bytes[3] in '\x7F\xFF') - and (float_bytes[2] >= '\x80')): + if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'): # If at least one significand bit is set... - if float_bytes[0:3] != '\x00\x00\x80': + if float_bytes[0:3] != b'\x00\x00\x80': return (_NAN, new_pos) # If sign bit is set... - if float_bytes[3] == '\xFF': + if float_bytes[3:4] == b'\xFF': return (_NEG_INF, new_pos) return (_POS_INF, new_pos) @@ -329,9 +337,9 @@ def _DoubleDecoder(): # If this value has all its exponent bits set and at least one significand # bit set, it's not a number. In Python 2.4, struct.unpack will treat it # as inf or -inf. To avoid that, we treat it specially. - if ((double_bytes[7] in '\x7F\xFF') - and (double_bytes[6] >= '\xF0') - and (double_bytes[0:7] != '\x00\x00\x00\x00\x00\x00\xF0')): + if ((double_bytes[7:8] in b'\x7F\xFF') + and (double_bytes[6:7] >= b'\xF0') + and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')): return (_NAN, new_pos) # Note that we expect someone up-stack to catch struct.error and convert @@ -342,10 +350,86 @@ def _DoubleDecoder(): return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode) +def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): + enum_type = key.enum_type + if is_packed: + local_DecodeVarint = _DecodeVarint + def DecodePackedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + (endpoint, pos) = local_DecodeVarint(buffer, pos) + endpoint += pos + if endpoint > end: + raise _DecodeError('Truncated message.') + while pos < endpoint: + value_start_pos = pos + (element, pos) = _DecodeSignedVarint32(buffer, pos) + if element in enum_type.values_by_number: + value.append(element) + else: + if not message._unknown_fields: + message._unknown_fields = [] + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_VARINT) + message._unknown_fields.append( + (tag_bytes, buffer[value_start_pos:pos])) + if pos > endpoint: + if element in enum_type.values_by_number: + del value[-1] # Discard corrupt value. + else: + del message._unknown_fields[-1] + raise _DecodeError('Packed element was truncated.') + return pos + return DecodePackedField + elif is_repeated: + tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + (element, new_pos) = _DecodeSignedVarint32(buffer, pos) + if element in enum_type.values_by_number: + value.append(element) + else: + if not message._unknown_fields: + message._unknown_fields = [] + message._unknown_fields.append( + (tag_bytes, buffer[pos:new_pos])) + # Predict that the next tag is another copy of the same repeated + # field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos >= end: + # Prediction failed. Return. + if new_pos > end: + raise _DecodeError('Truncated message.') + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + value_start_pos = pos + (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) + if pos > end: + raise _DecodeError('Truncated message.') + if enum_value in enum_type.values_by_number: + field_dict[key] = enum_value + else: + if not message._unknown_fields: + message._unknown_fields = [] + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_VARINT) + message._unknown_fields.append( + (tag_bytes, buffer[value_start_pos:pos])) + return pos + return DecodeField + + # -------------------------------------------------------------------- -Int32Decoder = EnumDecoder = _SimpleDecoder( +Int32Decoder = _SimpleDecoder( wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32) Int64Decoder = _SimpleDecoder( @@ -378,7 +462,15 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): """Returns a decoder for a string field.""" local_DecodeVarint = _DecodeVarint - local_unicode = unicode + local_unicode = six.text_type + + def _ConvertToUnicode(byte_str): + try: + return local_unicode(byte_str, 'utf-8') + except UnicodeDecodeError as e: + # add more information to the error message and re-raise it. + e.reason = '%s in field: %s' % (e, key.full_name) + raise assert not is_packed if is_repeated: @@ -394,7 +486,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - value.append(local_unicode(buffer[pos:new_pos], 'utf-8')) + value.append(_ConvertToUnicode(buffer[pos:new_pos])) # Predict that the next tag is another copy of the same repeated field. pos = new_pos + tag_len if buffer[new_pos:pos] != tag_bytes or new_pos == end: @@ -407,7 +499,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8') + field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos]) return new_pos return DecodeField @@ -511,9 +603,6 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): if value is None: value = field_dict.setdefault(key, new_default(message)) while 1: - value = field_dict.get(key) - if value is None: - value = field_dict.setdefault(key, new_default(message)) # Read length. (size, pos) = local_DecodeVarint(buffer, pos) new_pos = pos + size @@ -626,13 +715,59 @@ def MessageSetItemDecoder(extensions_by_number): return DecodeItem # -------------------------------------------------------------------- + +def MapDecoder(field_descriptor, new_default, is_message_map): + """Returns a decoder for a map field.""" + + key = field_descriptor + tag_bytes = encoder.TagBytes(field_descriptor.number, + wire_format.WIRETYPE_LENGTH_DELIMITED) + tag_len = len(tag_bytes) + local_DecodeVarint = _DecodeVarint + # Can't read _concrete_class yet; might not be initialized. + message_type = field_descriptor.message_type + + def DecodeMap(buffer, pos, end, message, field_dict): + submsg = message_type._concrete_class() + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + # Read length. + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated message.') + # Read sub-message. + submsg.Clear() + if submsg._InternalParse(buffer, pos, new_pos) != new_pos: + # The only reason _InternalParse would return early is if it + # encountered an end-group tag. + raise _DecodeError('Unexpected end-group tag.') + + if is_message_map: + value[submsg.key].MergeFrom(submsg.value) + else: + value[submsg.key] = submsg.value + + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. + return new_pos + + return DecodeMap + +# -------------------------------------------------------------------- # Optimization is not as heavy here because calls to SkipField() are rare, # except for handling end-group tags. def _SkipVarint(buffer, pos, end): """Skip a varint value. Returns the new position.""" - - while ord(buffer[pos]) & 0x80: + # Previously ord(buffer[pos]) raised IndexError when pos is out of range. + # With this code, ord(b'') raises TypeError. Both are handled in + # python_message.py to generate a 'Truncated message' error. + while ord(buffer[pos:pos+1]) & 0x80: pos += 1 pos += 1 if pos > end: @@ -699,7 +834,6 @@ def _FieldSkipper(): ] wiretype_mask = wire_format.TAG_TYPE_MASK - local_ord = ord def SkipField(buffer, pos, end, tag_bytes): """Skips a field with the specified tag. @@ -712,7 +846,7 @@ def _FieldSkipper(): """ # The wire type is always in the first byte since varints are little-endian. - wire_type = local_ord(tag_bytes[0]) & wiretype_mask + wire_type = ord(tag_bytes[0:1]) & wiretype_mask return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) return SkipField |