diff options
Diffstat (limited to 'generator/google/protobuf/internal/decoder.py')
-rw-r--r-- | generator/google/protobuf/internal/decoder.py | 194 |
1 files changed, 30 insertions, 164 deletions
diff --git a/generator/google/protobuf/internal/decoder.py b/generator/google/protobuf/internal/decoder.py index 31869e4..cb6f572 100644 --- a/generator/google/protobuf/internal/decoder.py +++ b/generator/google/protobuf/internal/decoder.py @@ -1,6 +1,6 @@ # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# https://developers.google.com/protocol-buffers/ +# http://code.google.com/p/protobuf/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -81,12 +81,6 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it. __author__ = 'kenton@google.com (Kenton Varda)' import struct - -import six - -if six.PY3: - long = int - from google.protobuf.internal import encoder from google.protobuf.internal import wire_format from google.protobuf import message @@ -104,7 +98,7 @@ _NAN = _POS_INF * 0 _DecodeError = message.DecodeError -def _VarintDecoder(mask, result_type): +def _VarintDecoder(mask): """Return an encoder for a basic varint value (does not include tag). Decoded values will be bitwise-anded with the given mask before being @@ -114,16 +108,16 @@ def _VarintDecoder(mask, result_type): decoder returns a (value, new_pos) pair. """ + local_ord = ord def DecodeVarint(buffer, pos): result = 0 shift = 0 while 1: - b = six.indexbytes(buffer, pos) + b = local_ord(buffer[pos]) result |= ((b & 0x7f) << shift) pos += 1 if not (b & 0x80): result &= mask - result = result_type(result) return (result, pos) shift += 7 if shift >= 64: @@ -131,14 +125,15 @@ def _VarintDecoder(mask, result_type): return DecodeVarint -def _SignedVarintDecoder(mask, result_type): +def _SignedVarintDecoder(mask): """Like _VarintDecoder() but decodes signed values.""" + local_ord = ord def DecodeVarint(buffer, pos): result = 0 shift = 0 while 1: - b = six.indexbytes(buffer, pos) + b = local_ord(buffer[pos]) result |= ((b & 0x7f) << shift) pos += 1 if not (b & 0x80): @@ -147,23 +142,19 @@ def _SignedVarintDecoder(mask, result_type): result |= ~mask else: result &= mask - result = result_type(result) return (result, pos) shift += 7 if shift >= 64: raise _DecodeError('Too many bytes when decoding varint.') return DecodeVarint -# We force 32-bit values to int and 64-bit values to long to make -# alternate implementations where the distinction is more significant -# (e.g. the C++ implementation) simpler. -_DecodeVarint = _VarintDecoder((1 << 64) - 1, long) -_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long) +_DecodeVarint = _VarintDecoder((1 << 64) - 1) +_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1) # Use these versions for values which must be limited to 32 bits. -_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int) -_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int) +_DecodeVarint32 = _VarintDecoder((1 << 32) - 1) +_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1) def ReadTag(buffer, pos): @@ -178,7 +169,7 @@ def ReadTag(buffer, pos): """ start = pos - while six.indexbytes(buffer, pos) & 0x80: + while ord(buffer[pos]) & 0x80: pos += 1 pos += 1 return (buffer[start:pos], pos) @@ -303,12 +294,13 @@ def _FloatDecoder(): # If this value has all its exponent bits set, then it's non-finite. # In Python 2.4, struct.unpack will convert it to a finite 64-bit value. # To avoid that, we parse it specially. - if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'): + if ((float_bytes[3] in '\x7F\xFF') + and (float_bytes[2] >= '\x80')): # If at least one significand bit is set... - if float_bytes[0:3] != b'\x00\x00\x80': + if float_bytes[0:3] != '\x00\x00\x80': return (_NAN, new_pos) # If sign bit is set... - if float_bytes[3:4] == b'\xFF': + if float_bytes[3] == '\xFF': return (_NEG_INF, new_pos) return (_POS_INF, new_pos) @@ -337,9 +329,9 @@ def _DoubleDecoder(): # If this value has all its exponent bits set and at least one significand # bit set, it's not a number. In Python 2.4, struct.unpack will treat it # as inf or -inf. To avoid that, we treat it specially. - if ((double_bytes[7:8] in b'\x7F\xFF') - and (double_bytes[6:7] >= b'\xF0') - and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')): + if ((double_bytes[7] in '\x7F\xFF') + and (double_bytes[6] >= '\xF0') + and (double_bytes[0:7] != '\x00\x00\x00\x00\x00\x00\xF0')): return (_NAN, new_pos) # Note that we expect someone up-stack to catch struct.error and convert @@ -350,86 +342,10 @@ def _DoubleDecoder(): return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode) -def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): - enum_type = key.enum_type - if is_packed: - local_DecodeVarint = _DecodeVarint - def DecodePackedField(buffer, pos, end, message, field_dict): - value = field_dict.get(key) - if value is None: - value = field_dict.setdefault(key, new_default(message)) - (endpoint, pos) = local_DecodeVarint(buffer, pos) - endpoint += pos - if endpoint > end: - raise _DecodeError('Truncated message.') - while pos < endpoint: - value_start_pos = pos - (element, pos) = _DecodeSignedVarint32(buffer, pos) - if element in enum_type.values_by_number: - value.append(element) - else: - if not message._unknown_fields: - message._unknown_fields = [] - tag_bytes = encoder.TagBytes(field_number, - wire_format.WIRETYPE_VARINT) - message._unknown_fields.append( - (tag_bytes, buffer[value_start_pos:pos])) - if pos > endpoint: - if element in enum_type.values_by_number: - del value[-1] # Discard corrupt value. - else: - del message._unknown_fields[-1] - raise _DecodeError('Packed element was truncated.') - return pos - return DecodePackedField - elif is_repeated: - tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) - tag_len = len(tag_bytes) - def DecodeRepeatedField(buffer, pos, end, message, field_dict): - value = field_dict.get(key) - if value is None: - value = field_dict.setdefault(key, new_default(message)) - while 1: - (element, new_pos) = _DecodeSignedVarint32(buffer, pos) - if element in enum_type.values_by_number: - value.append(element) - else: - if not message._unknown_fields: - message._unknown_fields = [] - message._unknown_fields.append( - (tag_bytes, buffer[pos:new_pos])) - # Predict that the next tag is another copy of the same repeated - # field. - pos = new_pos + tag_len - if buffer[new_pos:pos] != tag_bytes or new_pos >= end: - # Prediction failed. Return. - if new_pos > end: - raise _DecodeError('Truncated message.') - return new_pos - return DecodeRepeatedField - else: - def DecodeField(buffer, pos, end, message, field_dict): - value_start_pos = pos - (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) - if pos > end: - raise _DecodeError('Truncated message.') - if enum_value in enum_type.values_by_number: - field_dict[key] = enum_value - else: - if not message._unknown_fields: - message._unknown_fields = [] - tag_bytes = encoder.TagBytes(field_number, - wire_format.WIRETYPE_VARINT) - message._unknown_fields.append( - (tag_bytes, buffer[value_start_pos:pos])) - return pos - return DecodeField - - # -------------------------------------------------------------------- -Int32Decoder = _SimpleDecoder( +Int32Decoder = EnumDecoder = _SimpleDecoder( wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32) Int64Decoder = _SimpleDecoder( @@ -462,15 +378,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): """Returns a decoder for a string field.""" local_DecodeVarint = _DecodeVarint - local_unicode = six.text_type - - def _ConvertToUnicode(byte_str): - try: - return local_unicode(byte_str, 'utf-8') - except UnicodeDecodeError as e: - # add more information to the error message and re-raise it. - e.reason = '%s in field: %s' % (e, key.full_name) - raise + local_unicode = unicode assert not is_packed if is_repeated: @@ -486,7 +394,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - value.append(_ConvertToUnicode(buffer[pos:new_pos])) + value.append(local_unicode(buffer[pos:new_pos], 'utf-8')) # Predict that the next tag is another copy of the same repeated field. pos = new_pos + tag_len if buffer[new_pos:pos] != tag_bytes or new_pos == end: @@ -499,7 +407,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos]) + field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8') return new_pos return DecodeField @@ -603,6 +511,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): if value is None: value = field_dict.setdefault(key, new_default(message)) while 1: + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) # Read length. (size, pos) = local_DecodeVarint(buffer, pos) new_pos = pos + size @@ -715,59 +626,13 @@ def MessageSetItemDecoder(extensions_by_number): return DecodeItem # -------------------------------------------------------------------- - -def MapDecoder(field_descriptor, new_default, is_message_map): - """Returns a decoder for a map field.""" - - key = field_descriptor - tag_bytes = encoder.TagBytes(field_descriptor.number, - wire_format.WIRETYPE_LENGTH_DELIMITED) - tag_len = len(tag_bytes) - local_DecodeVarint = _DecodeVarint - # Can't read _concrete_class yet; might not be initialized. - message_type = field_descriptor.message_type - - def DecodeMap(buffer, pos, end, message, field_dict): - submsg = message_type._concrete_class() - value = field_dict.get(key) - if value is None: - value = field_dict.setdefault(key, new_default(message)) - while 1: - # Read length. - (size, pos) = local_DecodeVarint(buffer, pos) - new_pos = pos + size - if new_pos > end: - raise _DecodeError('Truncated message.') - # Read sub-message. - submsg.Clear() - if submsg._InternalParse(buffer, pos, new_pos) != new_pos: - # The only reason _InternalParse would return early is if it - # encountered an end-group tag. - raise _DecodeError('Unexpected end-group tag.') - - if is_message_map: - value[submsg.key].MergeFrom(submsg.value) - else: - value[submsg.key] = submsg.value - - # Predict that the next tag is another copy of the same repeated field. - pos = new_pos + tag_len - if buffer[new_pos:pos] != tag_bytes or new_pos == end: - # Prediction failed. Return. - return new_pos - - return DecodeMap - -# -------------------------------------------------------------------- # Optimization is not as heavy here because calls to SkipField() are rare, # except for handling end-group tags. def _SkipVarint(buffer, pos, end): """Skip a varint value. Returns the new position.""" - # Previously ord(buffer[pos]) raised IndexError when pos is out of range. - # With this code, ord(b'') raises TypeError. Both are handled in - # python_message.py to generate a 'Truncated message' error. - while ord(buffer[pos:pos+1]) & 0x80: + + while ord(buffer[pos]) & 0x80: pos += 1 pos += 1 if pos > end: @@ -834,6 +699,7 @@ def _FieldSkipper(): ] wiretype_mask = wire_format.TAG_TYPE_MASK + local_ord = ord def SkipField(buffer, pos, end, tag_bytes): """Skips a field with the specified tag. @@ -846,7 +712,7 @@ def _FieldSkipper(): """ # The wire type is always in the first byte since varints are little-endian. - wire_type = ord(tag_bytes[0:1]) & wiretype_mask + wire_type = local_ord(tag_bytes[0]) & wiretype_mask return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) return SkipField |