aboutsummaryrefslogtreecommitdiff
path: root/catapult/common/py_trace_event/third_party/protobuf/encoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'catapult/common/py_trace_event/third_party/protobuf/encoder.py')
-rw-r--r--catapult/common/py_trace_event/third_party/protobuf/encoder.py82
1 files changed, 82 insertions, 0 deletions
diff --git a/catapult/common/py_trace_event/third_party/protobuf/encoder.py b/catapult/common/py_trace_event/third_party/protobuf/encoder.py
index 18aaccdc..50d10465 100644
--- a/catapult/common/py_trace_event/third_party/protobuf/encoder.py
+++ b/catapult/common/py_trace_event/third_party/protobuf/encoder.py
@@ -29,6 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import six
+import struct
import wire_format
@@ -155,12 +156,93 @@ def _SimpleEncoder(wire_type, encode_value, compute_value_size):
return SpecificEncoder
+def _FloatingPointEncoder(wire_type, format):
+ """Return a constructor for an encoder for float fields.
+
+ This is like StructPackEncoder, but catches errors that may be due to
+ passing non-finite floating-point values to struct.pack, and makes a
+ second attempt to encode those values.
+
+ Args:
+ wire_type: The field's wire type, for encoding tags.
+ format: The format string to pass to struct.pack().
+ """
+
+ value_size = struct.calcsize(format)
+ if value_size == 4:
+ def EncodeNonFiniteOrRaise(write, value):
+ # Remember that the serialized form uses little-endian byte order.
+ if value == _POS_INF:
+ write(b'\x00\x00\x80\x7F')
+ elif value == _NEG_INF:
+ write(b'\x00\x00\x80\xFF')
+ elif value != value: # NaN
+ write(b'\x00\x00\xC0\x7F')
+ else:
+ raise
+ elif value_size == 8:
+ def EncodeNonFiniteOrRaise(write, value):
+ if value == _POS_INF:
+ write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F')
+ elif value == _NEG_INF:
+ write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF')
+ elif value != value: # NaN
+ write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F')
+ else:
+ raise
+ else:
+ raise ValueError('Can\'t encode floating-point values that are '
+ '%d bytes long (only 4 or 8)' % value_size)
+
+ def SpecificEncoder(field_number, is_repeated, is_packed):
+ local_struct_pack = struct.pack
+ if is_packed:
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ local_EncodeVarint = _EncodeVarint
+ def EncodePackedField(write, value):
+ write(tag_bytes)
+ local_EncodeVarint(write, len(value) * value_size)
+ for element in value:
+ # This try/except block is going to be faster than any code that
+ # we could write to check whether element is finite.
+ try:
+ write(local_struct_pack(format, element))
+ except SystemError:
+ EncodeNonFiniteOrRaise(write, element)
+ return EncodePackedField
+ elif is_repeated:
+ tag_bytes = TagBytes(field_number, wire_type)
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ write(tag_bytes)
+ try:
+ write(local_struct_pack(format, element))
+ except SystemError:
+ EncodeNonFiniteOrRaise(write, element)
+ return EncodeRepeatedField
+ else:
+ tag_bytes = TagBytes(field_number, wire_type)
+ def EncodeField(write, value):
+ write(tag_bytes)
+ try:
+ write(local_struct_pack(format, value))
+ except SystemError:
+ EncodeNonFiniteOrRaise(write, value)
+ return EncodeField
+
+ return SpecificEncoder
+
+
Int32Encoder = Int64Encoder = EnumEncoder = _SimpleEncoder(
wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize)
UInt32Encoder = UInt64Encoder = _SimpleEncoder(
wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize)
+FloatEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED32, '<f')
+
+DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d')
+
def BoolEncoder(field_number, is_repeated, is_packed):
"""Returns an encoder for a boolean field."""