aboutsummaryrefslogtreecommitdiff
path: root/generator/google/protobuf/internal/decoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'generator/google/protobuf/internal/decoder.py')
-rw-r--r--generator/google/protobuf/internal/decoder.py194
1 files changed, 30 insertions, 164 deletions
diff --git a/generator/google/protobuf/internal/decoder.py b/generator/google/protobuf/internal/decoder.py
index 31869e4..cb6f572 100644
--- a/generator/google/protobuf/internal/decoder.py
+++ b/generator/google/protobuf/internal/decoder.py
@@ -1,6 +1,6 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
-# https://developers.google.com/protocol-buffers/
+# http://code.google.com/p/protobuf/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
@@ -81,12 +81,6 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it.
__author__ = 'kenton@google.com (Kenton Varda)'
import struct
-
-import six
-
-if six.PY3:
- long = int
-
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf import message
@@ -104,7 +98,7 @@ _NAN = _POS_INF * 0
_DecodeError = message.DecodeError
-def _VarintDecoder(mask, result_type):
+def _VarintDecoder(mask):
"""Return an encoder for a basic varint value (does not include tag).
Decoded values will be bitwise-anded with the given mask before being
@@ -114,16 +108,16 @@ def _VarintDecoder(mask, result_type):
decoder returns a (value, new_pos) pair.
"""
+ local_ord = ord
def DecodeVarint(buffer, pos):
result = 0
shift = 0
while 1:
- b = six.indexbytes(buffer, pos)
+ b = local_ord(buffer[pos])
result |= ((b & 0x7f) << shift)
pos += 1
if not (b & 0x80):
result &= mask
- result = result_type(result)
return (result, pos)
shift += 7
if shift >= 64:
@@ -131,14 +125,15 @@ def _VarintDecoder(mask, result_type):
return DecodeVarint
-def _SignedVarintDecoder(mask, result_type):
+def _SignedVarintDecoder(mask):
"""Like _VarintDecoder() but decodes signed values."""
+ local_ord = ord
def DecodeVarint(buffer, pos):
result = 0
shift = 0
while 1:
- b = six.indexbytes(buffer, pos)
+ b = local_ord(buffer[pos])
result |= ((b & 0x7f) << shift)
pos += 1
if not (b & 0x80):
@@ -147,23 +142,19 @@ def _SignedVarintDecoder(mask, result_type):
result |= ~mask
else:
result &= mask
- result = result_type(result)
return (result, pos)
shift += 7
if shift >= 64:
raise _DecodeError('Too many bytes when decoding varint.')
return DecodeVarint
-# We force 32-bit values to int and 64-bit values to long to make
-# alternate implementations where the distinction is more significant
-# (e.g. the C++ implementation) simpler.
-_DecodeVarint = _VarintDecoder((1 << 64) - 1, long)
-_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long)
+_DecodeVarint = _VarintDecoder((1 << 64) - 1)
+_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1)
# Use these versions for values which must be limited to 32 bits.
-_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
-_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int)
+_DecodeVarint32 = _VarintDecoder((1 << 32) - 1)
+_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1)
def ReadTag(buffer, pos):
@@ -178,7 +169,7 @@ def ReadTag(buffer, pos):
"""
start = pos
- while six.indexbytes(buffer, pos) & 0x80:
+ while ord(buffer[pos]) & 0x80:
pos += 1
pos += 1
return (buffer[start:pos], pos)
@@ -303,12 +294,13 @@ def _FloatDecoder():
# If this value has all its exponent bits set, then it's non-finite.
# In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
# To avoid that, we parse it specially.
- if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'):
+ if ((float_bytes[3] in '\x7F\xFF')
+ and (float_bytes[2] >= '\x80')):
# If at least one significand bit is set...
- if float_bytes[0:3] != b'\x00\x00\x80':
+ if float_bytes[0:3] != '\x00\x00\x80':
return (_NAN, new_pos)
# If sign bit is set...
- if float_bytes[3:4] == b'\xFF':
+ if float_bytes[3] == '\xFF':
return (_NEG_INF, new_pos)
return (_POS_INF, new_pos)
@@ -337,9 +329,9 @@ def _DoubleDecoder():
# If this value has all its exponent bits set and at least one significand
# bit set, it's not a number. In Python 2.4, struct.unpack will treat it
# as inf or -inf. To avoid that, we treat it specially.
- if ((double_bytes[7:8] in b'\x7F\xFF')
- and (double_bytes[6:7] >= b'\xF0')
- and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
+ if ((double_bytes[7] in '\x7F\xFF')
+ and (double_bytes[6] >= '\xF0')
+ and (double_bytes[0:7] != '\x00\x00\x00\x00\x00\x00\xF0')):
return (_NAN, new_pos)
# Note that we expect someone up-stack to catch struct.error and convert
@@ -350,86 +342,10 @@ def _DoubleDecoder():
return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
-def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
- enum_type = key.enum_type
- if is_packed:
- local_DecodeVarint = _DecodeVarint
- def DecodePackedField(buffer, pos, end, message, field_dict):
- value = field_dict.get(key)
- if value is None:
- value = field_dict.setdefault(key, new_default(message))
- (endpoint, pos) = local_DecodeVarint(buffer, pos)
- endpoint += pos
- if endpoint > end:
- raise _DecodeError('Truncated message.')
- while pos < endpoint:
- value_start_pos = pos
- (element, pos) = _DecodeSignedVarint32(buffer, pos)
- if element in enum_type.values_by_number:
- value.append(element)
- else:
- if not message._unknown_fields:
- message._unknown_fields = []
- tag_bytes = encoder.TagBytes(field_number,
- wire_format.WIRETYPE_VARINT)
- message._unknown_fields.append(
- (tag_bytes, buffer[value_start_pos:pos]))
- if pos > endpoint:
- if element in enum_type.values_by_number:
- del value[-1] # Discard corrupt value.
- else:
- del message._unknown_fields[-1]
- raise _DecodeError('Packed element was truncated.')
- return pos
- return DecodePackedField
- elif is_repeated:
- tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
- tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
- value = field_dict.get(key)
- if value is None:
- value = field_dict.setdefault(key, new_default(message))
- while 1:
- (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
- if element in enum_type.values_by_number:
- value.append(element)
- else:
- if not message._unknown_fields:
- message._unknown_fields = []
- message._unknown_fields.append(
- (tag_bytes, buffer[pos:new_pos]))
- # Predict that the next tag is another copy of the same repeated
- # field.
- pos = new_pos + tag_len
- if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
- # Prediction failed. Return.
- if new_pos > end:
- raise _DecodeError('Truncated message.')
- return new_pos
- return DecodeRepeatedField
- else:
- def DecodeField(buffer, pos, end, message, field_dict):
- value_start_pos = pos
- (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
- if pos > end:
- raise _DecodeError('Truncated message.')
- if enum_value in enum_type.values_by_number:
- field_dict[key] = enum_value
- else:
- if not message._unknown_fields:
- message._unknown_fields = []
- tag_bytes = encoder.TagBytes(field_number,
- wire_format.WIRETYPE_VARINT)
- message._unknown_fields.append(
- (tag_bytes, buffer[value_start_pos:pos]))
- return pos
- return DecodeField
-
-
# --------------------------------------------------------------------
-Int32Decoder = _SimpleDecoder(
+Int32Decoder = EnumDecoder = _SimpleDecoder(
wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)
Int64Decoder = _SimpleDecoder(
@@ -462,15 +378,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
"""Returns a decoder for a string field."""
local_DecodeVarint = _DecodeVarint
- local_unicode = six.text_type
-
- def _ConvertToUnicode(byte_str):
- try:
- return local_unicode(byte_str, 'utf-8')
- except UnicodeDecodeError as e:
- # add more information to the error message and re-raise it.
- e.reason = '%s in field: %s' % (e, key.full_name)
- raise
+ local_unicode = unicode
assert not is_packed
if is_repeated:
@@ -486,7 +394,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
- value.append(_ConvertToUnicode(buffer[pos:new_pos]))
+ value.append(local_unicode(buffer[pos:new_pos], 'utf-8'))
# Predict that the next tag is another copy of the same repeated field.
pos = new_pos + tag_len
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
@@ -499,7 +407,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
- field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
+ field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8')
return new_pos
return DecodeField
@@ -603,6 +511,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
if value is None:
value = field_dict.setdefault(key, new_default(message))
while 1:
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
# Read length.
(size, pos) = local_DecodeVarint(buffer, pos)
new_pos = pos + size
@@ -715,59 +626,13 @@ def MessageSetItemDecoder(extensions_by_number):
return DecodeItem
# --------------------------------------------------------------------
-
-def MapDecoder(field_descriptor, new_default, is_message_map):
- """Returns a decoder for a map field."""
-
- key = field_descriptor
- tag_bytes = encoder.TagBytes(field_descriptor.number,
- wire_format.WIRETYPE_LENGTH_DELIMITED)
- tag_len = len(tag_bytes)
- local_DecodeVarint = _DecodeVarint
- # Can't read _concrete_class yet; might not be initialized.
- message_type = field_descriptor.message_type
-
- def DecodeMap(buffer, pos, end, message, field_dict):
- submsg = message_type._concrete_class()
- value = field_dict.get(key)
- if value is None:
- value = field_dict.setdefault(key, new_default(message))
- while 1:
- # Read length.
- (size, pos) = local_DecodeVarint(buffer, pos)
- new_pos = pos + size
- if new_pos > end:
- raise _DecodeError('Truncated message.')
- # Read sub-message.
- submsg.Clear()
- if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
- # The only reason _InternalParse would return early is if it
- # encountered an end-group tag.
- raise _DecodeError('Unexpected end-group tag.')
-
- if is_message_map:
- value[submsg.key].MergeFrom(submsg.value)
- else:
- value[submsg.key] = submsg.value
-
- # Predict that the next tag is another copy of the same repeated field.
- pos = new_pos + tag_len
- if buffer[new_pos:pos] != tag_bytes or new_pos == end:
- # Prediction failed. Return.
- return new_pos
-
- return DecodeMap
-
-# --------------------------------------------------------------------
# Optimization is not as heavy here because calls to SkipField() are rare,
# except for handling end-group tags.
def _SkipVarint(buffer, pos, end):
"""Skip a varint value. Returns the new position."""
- # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
- # With this code, ord(b'') raises TypeError. Both are handled in
- # python_message.py to generate a 'Truncated message' error.
- while ord(buffer[pos:pos+1]) & 0x80:
+
+ while ord(buffer[pos]) & 0x80:
pos += 1
pos += 1
if pos > end:
@@ -834,6 +699,7 @@ def _FieldSkipper():
]
wiretype_mask = wire_format.TAG_TYPE_MASK
+ local_ord = ord
def SkipField(buffer, pos, end, tag_bytes):
"""Skips a field with the specified tag.
@@ -846,7 +712,7 @@ def _FieldSkipper():
"""
# The wire type is always in the first byte since varints are little-endian.
- wire_type = ord(tag_bytes[0:1]) & wiretype_mask
+ wire_type = local_ord(tag_bytes[0]) & wiretype_mask
return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
return SkipField