diff options
Diffstat (limited to 'catapult/common/py_trace_event/third_party/protobuf/encoder.py')
-rw-r--r-- | catapult/common/py_trace_event/third_party/protobuf/encoder.py | 82 |
1 files changed, 82 insertions, 0 deletions
diff --git a/catapult/common/py_trace_event/third_party/protobuf/encoder.py b/catapult/common/py_trace_event/third_party/protobuf/encoder.py index 18aaccdc..50d10465 100644 --- a/catapult/common/py_trace_event/third_party/protobuf/encoder.py +++ b/catapult/common/py_trace_event/third_party/protobuf/encoder.py @@ -29,6 +29,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import six +import struct import wire_format @@ -155,12 +156,93 @@ def _SimpleEncoder(wire_type, encode_value, compute_value_size): return SpecificEncoder +def _FloatingPointEncoder(wire_type, format): + """Return a constructor for an encoder for float fields. + + This is like StructPackEncoder, but catches errors that may be due to + passing non-finite floating-point values to struct.pack, and makes a + second attempt to encode those values. + + Args: + wire_type: The field's wire type, for encoding tags. + format: The format string to pass to struct.pack(). + """ + + value_size = struct.calcsize(format) + if value_size == 4: + def EncodeNonFiniteOrRaise(write, value): + # Remember that the serialized form uses little-endian byte order. + if value == _POS_INF: + write(b'\x00\x00\x80\x7F') + elif value == _NEG_INF: + write(b'\x00\x00\x80\xFF') + elif value != value: # NaN + write(b'\x00\x00\xC0\x7F') + else: + raise + elif value_size == 8: + def EncodeNonFiniteOrRaise(write, value): + if value == _POS_INF: + write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F') + elif value == _NEG_INF: + write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF') + elif value != value: # NaN + write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F') + else: + raise + else: + raise ValueError('Can\'t encode floating-point values that are ' + '%d bytes long (only 4 or 8)' % value_size) + + def SpecificEncoder(field_number, is_repeated, is_packed): + local_struct_pack = struct.pack + if is_packed: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + def EncodePackedField(write, value): + write(tag_bytes) + local_EncodeVarint(write, len(value) * value_size) + for element in value: + # This try/except block is going to be faster than any code that + # we could write to check whether element is finite. + try: + write(local_struct_pack(format, element)) + except SystemError: + EncodeNonFiniteOrRaise(write, element) + return EncodePackedField + elif is_repeated: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeRepeatedField(write, value): + for element in value: + write(tag_bytes) + try: + write(local_struct_pack(format, element)) + except SystemError: + EncodeNonFiniteOrRaise(write, element) + return EncodeRepeatedField + else: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeField(write, value): + write(tag_bytes) + try: + write(local_struct_pack(format, value)) + except SystemError: + EncodeNonFiniteOrRaise(write, value) + return EncodeField + + return SpecificEncoder + + Int32Encoder = Int64Encoder = EnumEncoder = _SimpleEncoder( wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize) UInt32Encoder = UInt64Encoder = _SimpleEncoder( wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize) +FloatEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED32, '<f') + +DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d') + def BoolEncoder(field_number, is_repeated, is_packed): """Returns an encoder for a boolean field.""" |