diff options
Diffstat (limited to 'generator/google/protobuf/internal/message_test.py')
-rw-r--r-- | generator/google/protobuf/internal/message_test.py | 1856 |
1 files changed, 0 insertions, 1856 deletions
diff --git a/generator/google/protobuf/internal/message_test.py b/generator/google/protobuf/internal/message_test.py deleted file mode 100644 index 1e95adf..0000000 --- a/generator/google/protobuf/internal/message_test.py +++ /dev/null @@ -1,1856 +0,0 @@ -#! /usr/bin/env python -# -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# https://developers.google.com/protocol-buffers/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Tests python protocol buffers against the golden message. - -Note that the golden messages exercise every known field type, thus this -test ends up exercising and verifying nearly all of the parsing and -serialization code in the whole library. - -TODO(kenton): Merge with wire_format_test? It doesn't make a whole lot of -sense to call this a test of the "message" module, which only declares an -abstract interface. -""" - -__author__ = 'gps@google.com (Gregory P. Smith)' - - -import collections -import copy -import math -import operator -import pickle -import six -import sys - -try: - import unittest2 as unittest #PY26 -except ImportError: - import unittest - -from google.protobuf import map_unittest_pb2 -from google.protobuf import unittest_pb2 -from google.protobuf import unittest_proto3_arena_pb2 -from google.protobuf import descriptor_pb2 -from google.protobuf import descriptor_pool -from google.protobuf import message_factory -from google.protobuf import text_format -from google.protobuf.internal import api_implementation -from google.protobuf.internal import packed_field_test_pb2 -from google.protobuf.internal import test_util -from google.protobuf import message -from google.protobuf.internal import _parameterized - -if six.PY3: - long = int - - -# Python pre-2.6 does not have isinf() or isnan() functions, so we have -# to provide our own. -def isnan(val): - # NaN is never equal to itself. - return val != val -def isinf(val): - # Infinity times zero equals NaN. - return not isnan(val) and isnan(val * 0) -def IsPosInf(val): - return isinf(val) and (val > 0) -def IsNegInf(val): - return isinf(val) and (val < 0) - - -@_parameterized.Parameters( - (unittest_pb2), - (unittest_proto3_arena_pb2)) -class MessageTest(unittest.TestCase): - - def testBadUtf8String(self, message_module): - if api_implementation.Type() != 'python': - self.skipTest("Skipping testBadUtf8String, currently only the python " - "api implementation raises UnicodeDecodeError when a " - "string field contains bad utf-8.") - bad_utf8_data = test_util.GoldenFileData('bad_utf8_string') - with self.assertRaises(UnicodeDecodeError) as context: - message_module.TestAllTypes.FromString(bad_utf8_data) - self.assertIn('TestAllTypes.optional_string', str(context.exception)) - - def testGoldenMessage(self, message_module): - # Proto3 doesn't have the "default_foo" members or foreign enums, - # and doesn't preserve unknown fields, so for proto3 we use a golden - # message that doesn't have these fields set. - if message_module is unittest_pb2: - golden_data = test_util.GoldenFileData( - 'golden_message_oneof_implemented') - else: - golden_data = test_util.GoldenFileData('golden_message_proto3') - - golden_message = message_module.TestAllTypes() - golden_message.ParseFromString(golden_data) - if message_module is unittest_pb2: - test_util.ExpectAllFieldsSet(self, golden_message) - self.assertEqual(golden_data, golden_message.SerializeToString()) - golden_copy = copy.deepcopy(golden_message) - self.assertEqual(golden_data, golden_copy.SerializeToString()) - - def testGoldenPackedMessage(self, message_module): - golden_data = test_util.GoldenFileData('golden_packed_fields_message') - golden_message = message_module.TestPackedTypes() - golden_message.ParseFromString(golden_data) - all_set = message_module.TestPackedTypes() - test_util.SetAllPackedFields(all_set) - self.assertEqual(all_set, golden_message) - self.assertEqual(golden_data, all_set.SerializeToString()) - golden_copy = copy.deepcopy(golden_message) - self.assertEqual(golden_data, golden_copy.SerializeToString()) - - def testPickleSupport(self, message_module): - golden_data = test_util.GoldenFileData('golden_message') - golden_message = message_module.TestAllTypes() - golden_message.ParseFromString(golden_data) - pickled_message = pickle.dumps(golden_message) - - unpickled_message = pickle.loads(pickled_message) - self.assertEqual(unpickled_message, golden_message) - - def testPositiveInfinity(self, message_module): - if message_module is unittest_pb2: - golden_data = (b'\x5D\x00\x00\x80\x7F' - b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' - b'\xCD\x02\x00\x00\x80\x7F' - b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') - else: - golden_data = (b'\x5D\x00\x00\x80\x7F' - b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' - b'\xCA\x02\x04\x00\x00\x80\x7F' - b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') - - golden_message = message_module.TestAllTypes() - golden_message.ParseFromString(golden_data) - self.assertTrue(IsPosInf(golden_message.optional_float)) - self.assertTrue(IsPosInf(golden_message.optional_double)) - self.assertTrue(IsPosInf(golden_message.repeated_float[0])) - self.assertTrue(IsPosInf(golden_message.repeated_double[0])) - self.assertEqual(golden_data, golden_message.SerializeToString()) - - def testNegativeInfinity(self, message_module): - if message_module is unittest_pb2: - golden_data = (b'\x5D\x00\x00\x80\xFF' - b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' - b'\xCD\x02\x00\x00\x80\xFF' - b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') - else: - golden_data = (b'\x5D\x00\x00\x80\xFF' - b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' - b'\xCA\x02\x04\x00\x00\x80\xFF' - b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') - - golden_message = message_module.TestAllTypes() - golden_message.ParseFromString(golden_data) - self.assertTrue(IsNegInf(golden_message.optional_float)) - self.assertTrue(IsNegInf(golden_message.optional_double)) - self.assertTrue(IsNegInf(golden_message.repeated_float[0])) - self.assertTrue(IsNegInf(golden_message.repeated_double[0])) - self.assertEqual(golden_data, golden_message.SerializeToString()) - - def testNotANumber(self, message_module): - golden_data = (b'\x5D\x00\x00\xC0\x7F' - b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F' - b'\xCD\x02\x00\x00\xC0\x7F' - b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F') - golden_message = message_module.TestAllTypes() - golden_message.ParseFromString(golden_data) - self.assertTrue(isnan(golden_message.optional_float)) - self.assertTrue(isnan(golden_message.optional_double)) - self.assertTrue(isnan(golden_message.repeated_float[0])) - self.assertTrue(isnan(golden_message.repeated_double[0])) - - # The protocol buffer may serialize to any one of multiple different - # representations of a NaN. Rather than verify a specific representation, - # verify the serialized string can be converted into a correctly - # behaving protocol buffer. - serialized = golden_message.SerializeToString() - message = message_module.TestAllTypes() - message.ParseFromString(serialized) - self.assertTrue(isnan(message.optional_float)) - self.assertTrue(isnan(message.optional_double)) - self.assertTrue(isnan(message.repeated_float[0])) - self.assertTrue(isnan(message.repeated_double[0])) - - def testPositiveInfinityPacked(self, message_module): - golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F' - b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') - golden_message = message_module.TestPackedTypes() - golden_message.ParseFromString(golden_data) - self.assertTrue(IsPosInf(golden_message.packed_float[0])) - self.assertTrue(IsPosInf(golden_message.packed_double[0])) - self.assertEqual(golden_data, golden_message.SerializeToString()) - - def testNegativeInfinityPacked(self, message_module): - golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF' - b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') - golden_message = message_module.TestPackedTypes() - golden_message.ParseFromString(golden_data) - self.assertTrue(IsNegInf(golden_message.packed_float[0])) - self.assertTrue(IsNegInf(golden_message.packed_double[0])) - self.assertEqual(golden_data, golden_message.SerializeToString()) - - def testNotANumberPacked(self, message_module): - golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F' - b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F') - golden_message = message_module.TestPackedTypes() - golden_message.ParseFromString(golden_data) - self.assertTrue(isnan(golden_message.packed_float[0])) - self.assertTrue(isnan(golden_message.packed_double[0])) - - serialized = golden_message.SerializeToString() - message = message_module.TestPackedTypes() - message.ParseFromString(serialized) - self.assertTrue(isnan(message.packed_float[0])) - self.assertTrue(isnan(message.packed_double[0])) - - def testExtremeFloatValues(self, message_module): - message = message_module.TestAllTypes() - - # Most positive exponent, no significand bits set. - kMostPosExponentNoSigBits = math.pow(2, 127) - message.optional_float = kMostPosExponentNoSigBits - message.ParseFromString(message.SerializeToString()) - self.assertTrue(message.optional_float == kMostPosExponentNoSigBits) - - # Most positive exponent, one significand bit set. - kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127) - message.optional_float = kMostPosExponentOneSigBit - message.ParseFromString(message.SerializeToString()) - self.assertTrue(message.optional_float == kMostPosExponentOneSigBit) - - # Repeat last two cases with values of same magnitude, but negative. - message.optional_float = -kMostPosExponentNoSigBits - message.ParseFromString(message.SerializeToString()) - self.assertTrue(message.optional_float == -kMostPosExponentNoSigBits) - - message.optional_float = -kMostPosExponentOneSigBit - message.ParseFromString(message.SerializeToString()) - self.assertTrue(message.optional_float == -kMostPosExponentOneSigBit) - - # Most negative exponent, no significand bits set. - kMostNegExponentNoSigBits = math.pow(2, -127) - message.optional_float = kMostNegExponentNoSigBits - message.ParseFromString(message.SerializeToString()) - self.assertTrue(message.optional_float == kMostNegExponentNoSigBits) - - # Most negative exponent, one significand bit set. - kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127) - message.optional_float = kMostNegExponentOneSigBit - message.ParseFromString(message.SerializeToString()) - self.assertTrue(message.optional_float == kMostNegExponentOneSigBit) - - # Repeat last two cases with values of the same magnitude, but negative. - message.optional_float = -kMostNegExponentNoSigBits - message.ParseFromString(message.SerializeToString()) - self.assertTrue(message.optional_float == -kMostNegExponentNoSigBits) - - message.optional_float = -kMostNegExponentOneSigBit - message.ParseFromString(message.SerializeToString()) - self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit) - - def testExtremeDoubleValues(self, message_module): - message = message_module.TestAllTypes() - - # Most positive exponent, no significand bits set. - kMostPosExponentNoSigBits = math.pow(2, 1023) - message.optional_double = kMostPosExponentNoSigBits - message.ParseFromString(message.SerializeToString()) - self.assertTrue(message.optional_double == kMostPosExponentNoSigBits) - - # Most positive exponent, one significand bit set. - kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023) - message.optional_double = kMostPosExponentOneSigBit - message.ParseFromString(message.SerializeToString()) - self.assertTrue(message.optional_double == kMostPosExponentOneSigBit) - - # Repeat last two cases with values of same magnitude, but negative. - message.optional_double = -kMostPosExponentNoSigBits - message.ParseFromString(message.SerializeToString()) - self.assertTrue(message.optional_double == -kMostPosExponentNoSigBits) - - message.optional_double = -kMostPosExponentOneSigBit - message.ParseFromString(message.SerializeToString()) - self.assertTrue(message.optional_double == -kMostPosExponentOneSigBit) - - # Most negative exponent, no significand bits set. - kMostNegExponentNoSigBits = math.pow(2, -1023) - message.optional_double = kMostNegExponentNoSigBits - message.ParseFromString(message.SerializeToString()) - self.assertTrue(message.optional_double == kMostNegExponentNoSigBits) - - # Most negative exponent, one significand bit set. - kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023) - message.optional_double = kMostNegExponentOneSigBit - message.ParseFromString(message.SerializeToString()) - self.assertTrue(message.optional_double == kMostNegExponentOneSigBit) - - # Repeat last two cases with values of the same magnitude, but negative. - message.optional_double = -kMostNegExponentNoSigBits - message.ParseFromString(message.SerializeToString()) - self.assertTrue(message.optional_double == -kMostNegExponentNoSigBits) - - message.optional_double = -kMostNegExponentOneSigBit - message.ParseFromString(message.SerializeToString()) - self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit) - - def testFloatPrinting(self, message_module): - message = message_module.TestAllTypes() - message.optional_float = 2.0 - self.assertEqual(str(message), 'optional_float: 2.0\n') - - def testHighPrecisionFloatPrinting(self, message_module): - message = message_module.TestAllTypes() - message.optional_double = 0.12345678912345678 - if sys.version_info >= (3,): - self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n') - else: - self.assertEqual(str(message), 'optional_double: 0.123456789123\n') - - def testUnknownFieldPrinting(self, message_module): - populated = message_module.TestAllTypes() - test_util.SetAllNonLazyFields(populated) - empty = message_module.TestEmptyMessage() - empty.ParseFromString(populated.SerializeToString()) - self.assertEqual(str(empty), '') - - def testRepeatedNestedFieldIteration(self, message_module): - msg = message_module.TestAllTypes() - msg.repeated_nested_message.add(bb=1) - msg.repeated_nested_message.add(bb=2) - msg.repeated_nested_message.add(bb=3) - msg.repeated_nested_message.add(bb=4) - - self.assertEqual([1, 2, 3, 4], - [m.bb for m in msg.repeated_nested_message]) - self.assertEqual([4, 3, 2, 1], - [m.bb for m in reversed(msg.repeated_nested_message)]) - self.assertEqual([4, 3, 2, 1], - [m.bb for m in msg.repeated_nested_message[::-1]]) - - def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module): - """Check some different types with the default comparator.""" - message = message_module.TestAllTypes() - - # TODO(mattp): would testing more scalar types strengthen test? - message.repeated_int32.append(1) - message.repeated_int32.append(3) - message.repeated_int32.append(2) - message.repeated_int32.sort() - self.assertEqual(message.repeated_int32[0], 1) - self.assertEqual(message.repeated_int32[1], 2) - self.assertEqual(message.repeated_int32[2], 3) - - message.repeated_float.append(1.1) - message.repeated_float.append(1.3) - message.repeated_float.append(1.2) - message.repeated_float.sort() - self.assertAlmostEqual(message.repeated_float[0], 1.1) - self.assertAlmostEqual(message.repeated_float[1], 1.2) - self.assertAlmostEqual(message.repeated_float[2], 1.3) - - message.repeated_string.append('a') - message.repeated_string.append('c') - message.repeated_string.append('b') - message.repeated_string.sort() - self.assertEqual(message.repeated_string[0], 'a') - self.assertEqual(message.repeated_string[1], 'b') - self.assertEqual(message.repeated_string[2], 'c') - - message.repeated_bytes.append(b'a') - message.repeated_bytes.append(b'c') - message.repeated_bytes.append(b'b') - message.repeated_bytes.sort() - self.assertEqual(message.repeated_bytes[0], b'a') - self.assertEqual(message.repeated_bytes[1], b'b') - self.assertEqual(message.repeated_bytes[2], b'c') - - def testSortingRepeatedScalarFieldsCustomComparator(self, message_module): - """Check some different types with custom comparator.""" - message = message_module.TestAllTypes() - - message.repeated_int32.append(-3) - message.repeated_int32.append(-2) - message.repeated_int32.append(-1) - message.repeated_int32.sort(key=abs) - self.assertEqual(message.repeated_int32[0], -1) - self.assertEqual(message.repeated_int32[1], -2) - self.assertEqual(message.repeated_int32[2], -3) - - message.repeated_string.append('aaa') - message.repeated_string.append('bb') - message.repeated_string.append('c') - message.repeated_string.sort(key=len) - self.assertEqual(message.repeated_string[0], 'c') - self.assertEqual(message.repeated_string[1], 'bb') - self.assertEqual(message.repeated_string[2], 'aaa') - - def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module): - """Check passing a custom comparator to sort a repeated composite field.""" - message = message_module.TestAllTypes() - - message.repeated_nested_message.add().bb = 1 - message.repeated_nested_message.add().bb = 3 - message.repeated_nested_message.add().bb = 2 - message.repeated_nested_message.add().bb = 6 - message.repeated_nested_message.add().bb = 5 - message.repeated_nested_message.add().bb = 4 - message.repeated_nested_message.sort(key=operator.attrgetter('bb')) - self.assertEqual(message.repeated_nested_message[0].bb, 1) - self.assertEqual(message.repeated_nested_message[1].bb, 2) - self.assertEqual(message.repeated_nested_message[2].bb, 3) - self.assertEqual(message.repeated_nested_message[3].bb, 4) - self.assertEqual(message.repeated_nested_message[4].bb, 5) - self.assertEqual(message.repeated_nested_message[5].bb, 6) - - def testSortingRepeatedCompositeFieldsStable(self, message_module): - """Check passing a custom comparator to sort a repeated composite field.""" - message = message_module.TestAllTypes() - - message.repeated_nested_message.add().bb = 21 - message.repeated_nested_message.add().bb = 20 - message.repeated_nested_message.add().bb = 13 - message.repeated_nested_message.add().bb = 33 - message.repeated_nested_message.add().bb = 11 - message.repeated_nested_message.add().bb = 24 - message.repeated_nested_message.add().bb = 10 - message.repeated_nested_message.sort(key=lambda z: z.bb // 10) - self.assertEqual( - [13, 11, 10, 21, 20, 24, 33], - [n.bb for n in message.repeated_nested_message]) - - # Make sure that for the C++ implementation, the underlying fields - # are actually reordered. - pb = message.SerializeToString() - message.Clear() - message.MergeFromString(pb) - self.assertEqual( - [13, 11, 10, 21, 20, 24, 33], - [n.bb for n in message.repeated_nested_message]) - - def testRepeatedCompositeFieldSortArguments(self, message_module): - """Check sorting a repeated composite field using list.sort() arguments.""" - message = message_module.TestAllTypes() - - get_bb = operator.attrgetter('bb') - cmp_bb = lambda a, b: cmp(a.bb, b.bb) - message.repeated_nested_message.add().bb = 1 - message.repeated_nested_message.add().bb = 3 - message.repeated_nested_message.add().bb = 2 - message.repeated_nested_message.add().bb = 6 - message.repeated_nested_message.add().bb = 5 - message.repeated_nested_message.add().bb = 4 - message.repeated_nested_message.sort(key=get_bb) - self.assertEqual([k.bb for k in message.repeated_nested_message], - [1, 2, 3, 4, 5, 6]) - message.repeated_nested_message.sort(key=get_bb, reverse=True) - self.assertEqual([k.bb for k in message.repeated_nested_message], - [6, 5, 4, 3, 2, 1]) - if sys.version_info >= (3,): return # No cmp sorting in PY3. - message.repeated_nested_message.sort(sort_function=cmp_bb) - self.assertEqual([k.bb for k in message.repeated_nested_message], - [1, 2, 3, 4, 5, 6]) - message.repeated_nested_message.sort(cmp=cmp_bb, reverse=True) - self.assertEqual([k.bb for k in message.repeated_nested_message], - [6, 5, 4, 3, 2, 1]) - - def testRepeatedScalarFieldSortArguments(self, message_module): - """Check sorting a scalar field using list.sort() arguments.""" - message = message_module.TestAllTypes() - - message.repeated_int32.append(-3) - message.repeated_int32.append(-2) - message.repeated_int32.append(-1) - message.repeated_int32.sort(key=abs) - self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) - message.repeated_int32.sort(key=abs, reverse=True) - self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) - if sys.version_info < (3,): # No cmp sorting in PY3. - abs_cmp = lambda a, b: cmp(abs(a), abs(b)) - message.repeated_int32.sort(sort_function=abs_cmp) - self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) - message.repeated_int32.sort(cmp=abs_cmp, reverse=True) - self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) - - message.repeated_string.append('aaa') - message.repeated_string.append('bb') - message.repeated_string.append('c') - message.repeated_string.sort(key=len) - self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) - message.repeated_string.sort(key=len, reverse=True) - self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) - if sys.version_info < (3,): # No cmp sorting in PY3. - len_cmp = lambda a, b: cmp(len(a), len(b)) - message.repeated_string.sort(sort_function=len_cmp) - self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) - message.repeated_string.sort(cmp=len_cmp, reverse=True) - self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) - - def testRepeatedFieldsComparable(self, message_module): - m1 = message_module.TestAllTypes() - m2 = message_module.TestAllTypes() - m1.repeated_int32.append(0) - m1.repeated_int32.append(1) - m1.repeated_int32.append(2) - m2.repeated_int32.append(0) - m2.repeated_int32.append(1) - m2.repeated_int32.append(2) - m1.repeated_nested_message.add().bb = 1 - m1.repeated_nested_message.add().bb = 2 - m1.repeated_nested_message.add().bb = 3 - m2.repeated_nested_message.add().bb = 1 - m2.repeated_nested_message.add().bb = 2 - m2.repeated_nested_message.add().bb = 3 - - if sys.version_info >= (3,): return # No cmp() in PY3. - - # These comparisons should not raise errors. - _ = m1 < m2 - _ = m1.repeated_nested_message < m2.repeated_nested_message - - # Make sure cmp always works. If it wasn't defined, these would be - # id() comparisons and would all fail. - self.assertEqual(cmp(m1, m2), 0) - self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0) - self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0) - self.assertEqual(cmp(m1.repeated_nested_message, - m2.repeated_nested_message), 0) - with self.assertRaises(TypeError): - # Can't compare repeated composite containers to lists. - cmp(m1.repeated_nested_message, m2.repeated_nested_message[:]) - - # TODO(anuraag): Implement extensiondict comparison in C++ and then add test - - def testRepeatedFieldsAreSequences(self, message_module): - m = message_module.TestAllTypes() - self.assertIsInstance(m.repeated_int32, collections.MutableSequence) - self.assertIsInstance(m.repeated_nested_message, - collections.MutableSequence) - - def ensureNestedMessageExists(self, msg, attribute): - """Make sure that a nested message object exists. - - As soon as a nested message attribute is accessed, it will be present in the - _fields dict, without being marked as actually being set. - """ - getattr(msg, attribute) - self.assertFalse(msg.HasField(attribute)) - - def testOneofGetCaseNonexistingField(self, message_module): - m = message_module.TestAllTypes() - self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field') - - def testOneofDefaultValues(self, message_module): - m = message_module.TestAllTypes() - self.assertIs(None, m.WhichOneof('oneof_field')) - self.assertFalse(m.HasField('oneof_uint32')) - - # Oneof is set even when setting it to a default value. - m.oneof_uint32 = 0 - self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) - self.assertTrue(m.HasField('oneof_uint32')) - self.assertFalse(m.HasField('oneof_string')) - - m.oneof_string = "" - self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) - self.assertTrue(m.HasField('oneof_string')) - self.assertFalse(m.HasField('oneof_uint32')) - - def testOneofSemantics(self, message_module): - m = message_module.TestAllTypes() - self.assertIs(None, m.WhichOneof('oneof_field')) - - m.oneof_uint32 = 11 - self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) - self.assertTrue(m.HasField('oneof_uint32')) - - m.oneof_string = u'foo' - self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) - self.assertFalse(m.HasField('oneof_uint32')) - self.assertTrue(m.HasField('oneof_string')) - - # Read nested message accessor without accessing submessage. - m.oneof_nested_message - self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) - self.assertTrue(m.HasField('oneof_string')) - self.assertFalse(m.HasField('oneof_nested_message')) - - # Read accessor of nested message without accessing submessage. - m.oneof_nested_message.bb - self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) - self.assertTrue(m.HasField('oneof_string')) - self.assertFalse(m.HasField('oneof_nested_message')) - - m.oneof_nested_message.bb = 11 - self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) - self.assertFalse(m.HasField('oneof_string')) - self.assertTrue(m.HasField('oneof_nested_message')) - - m.oneof_bytes = b'bb' - self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) - self.assertFalse(m.HasField('oneof_nested_message')) - self.assertTrue(m.HasField('oneof_bytes')) - - def testOneofCompositeFieldReadAccess(self, message_module): - m = message_module.TestAllTypes() - m.oneof_uint32 = 11 - - self.ensureNestedMessageExists(m, 'oneof_nested_message') - self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) - self.assertEqual(11, m.oneof_uint32) - - def testOneofWhichOneof(self, message_module): - m = message_module.TestAllTypes() - self.assertIs(None, m.WhichOneof('oneof_field')) - if message_module is unittest_pb2: - self.assertFalse(m.HasField('oneof_field')) - - m.oneof_uint32 = 11 - self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) - if message_module is unittest_pb2: - self.assertTrue(m.HasField('oneof_field')) - - m.oneof_bytes = b'bb' - self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) - - m.ClearField('oneof_bytes') - self.assertIs(None, m.WhichOneof('oneof_field')) - if message_module is unittest_pb2: - self.assertFalse(m.HasField('oneof_field')) - - def testOneofClearField(self, message_module): - m = message_module.TestAllTypes() - m.oneof_uint32 = 11 - m.ClearField('oneof_field') - if message_module is unittest_pb2: - self.assertFalse(m.HasField('oneof_field')) - self.assertFalse(m.HasField('oneof_uint32')) - self.assertIs(None, m.WhichOneof('oneof_field')) - - def testOneofClearSetField(self, message_module): - m = message_module.TestAllTypes() - m.oneof_uint32 = 11 - m.ClearField('oneof_uint32') - if message_module is unittest_pb2: - self.assertFalse(m.HasField('oneof_field')) - self.assertFalse(m.HasField('oneof_uint32')) - self.assertIs(None, m.WhichOneof('oneof_field')) - - def testOneofClearUnsetField(self, message_module): - m = message_module.TestAllTypes() - m.oneof_uint32 = 11 - self.ensureNestedMessageExists(m, 'oneof_nested_message') - m.ClearField('oneof_nested_message') - self.assertEqual(11, m.oneof_uint32) - if message_module is unittest_pb2: - self.assertTrue(m.HasField('oneof_field')) - self.assertTrue(m.HasField('oneof_uint32')) - self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) - - def testOneofDeserialize(self, message_module): - m = message_module.TestAllTypes() - m.oneof_uint32 = 11 - m2 = message_module.TestAllTypes() - m2.ParseFromString(m.SerializeToString()) - self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) - - def testOneofCopyFrom(self, message_module): - m = message_module.TestAllTypes() - m.oneof_uint32 = 11 - m2 = message_module.TestAllTypes() - m2.CopyFrom(m) - self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) - - def testOneofNestedMergeFrom(self, message_module): - m = message_module.NestedTestAllTypes() - m.payload.oneof_uint32 = 11 - m2 = message_module.NestedTestAllTypes() - m2.payload.oneof_bytes = b'bb' - m2.child.payload.oneof_bytes = b'bb' - m2.MergeFrom(m) - self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field')) - self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field')) - - def testOneofMessageMergeFrom(self, message_module): - m = message_module.NestedTestAllTypes() - m.payload.oneof_nested_message.bb = 11 - m.child.payload.oneof_nested_message.bb = 12 - m2 = message_module.NestedTestAllTypes() - m2.payload.oneof_uint32 = 13 - m2.MergeFrom(m) - self.assertEqual('oneof_nested_message', - m2.payload.WhichOneof('oneof_field')) - self.assertEqual('oneof_nested_message', - m2.child.payload.WhichOneof('oneof_field')) - - def testOneofNestedMessageInit(self, message_module): - m = message_module.TestAllTypes( - oneof_nested_message=message_module.TestAllTypes.NestedMessage()) - self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) - - def testOneofClear(self, message_module): - m = message_module.TestAllTypes() - m.oneof_uint32 = 11 - m.Clear() - self.assertIsNone(m.WhichOneof('oneof_field')) - m.oneof_bytes = b'bb' - self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) - - def testAssignByteStringToUnicodeField(self, message_module): - """Assigning a byte string to a string field should result - in the value being converted to a Unicode string.""" - m = message_module.TestAllTypes() - m.optional_string = str('') - self.assertIsInstance(m.optional_string, six.text_type) - - def testLongValuedSlice(self, message_module): - """It should be possible to use long-valued indicies in slices - - This didn't used to work in the v2 C++ implementation. - """ - m = message_module.TestAllTypes() - - # Repeated scalar - m.repeated_int32.append(1) - sl = m.repeated_int32[long(0):long(len(m.repeated_int32))] - self.assertEqual(len(m.repeated_int32), len(sl)) - - # Repeated composite - m.repeated_nested_message.add().bb = 3 - sl = m.repeated_nested_message[long(0):long(len(m.repeated_nested_message))] - self.assertEqual(len(m.repeated_nested_message), len(sl)) - - def testExtendShouldNotSwallowExceptions(self, message_module): - """This didn't use to work in the v2 C++ implementation.""" - m = message_module.TestAllTypes() - with self.assertRaises(NameError) as _: - m.repeated_int32.extend(a for i in range(10)) # pylint: disable=undefined-variable - with self.assertRaises(NameError) as _: - m.repeated_nested_enum.extend( - a for i in range(10)) # pylint: disable=undefined-variable - - FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()] - - def testExtendInt32WithNothing(self, message_module): - """Test no-ops extending repeated int32 fields.""" - m = message_module.TestAllTypes() - self.assertSequenceEqual([], m.repeated_int32) - - # TODO(ptucker): Deprecate this behavior. b/18413862 - for falsy_value in MessageTest.FALSY_VALUES: - m.repeated_int32.extend(falsy_value) - self.assertSequenceEqual([], m.repeated_int32) - - m.repeated_int32.extend([]) - self.assertSequenceEqual([], m.repeated_int32) - - def testExtendFloatWithNothing(self, message_module): - """Test no-ops extending repeated float fields.""" - m = message_module.TestAllTypes() - self.assertSequenceEqual([], m.repeated_float) - - # TODO(ptucker): Deprecate this behavior. b/18413862 - for falsy_value in MessageTest.FALSY_VALUES: - m.repeated_float.extend(falsy_value) - self.assertSequenceEqual([], m.repeated_float) - - m.repeated_float.extend([]) - self.assertSequenceEqual([], m.repeated_float) - - def testExtendStringWithNothing(self, message_module): - """Test no-ops extending repeated string fields.""" - m = message_module.TestAllTypes() - self.assertSequenceEqual([], m.repeated_string) - - # TODO(ptucker): Deprecate this behavior. b/18413862 - for falsy_value in MessageTest.FALSY_VALUES: - m.repeated_string.extend(falsy_value) - self.assertSequenceEqual([], m.repeated_string) - - m.repeated_string.extend([]) - self.assertSequenceEqual([], m.repeated_string) - - def testExtendInt32WithPythonList(self, message_module): - """Test extending repeated int32 fields with python lists.""" - m = message_module.TestAllTypes() - self.assertSequenceEqual([], m.repeated_int32) - m.repeated_int32.extend([0]) - self.assertSequenceEqual([0], m.repeated_int32) - m.repeated_int32.extend([1, 2]) - self.assertSequenceEqual([0, 1, 2], m.repeated_int32) - m.repeated_int32.extend([3, 4]) - self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32) - - def testExtendFloatWithPythonList(self, message_module): - """Test extending repeated float fields with python lists.""" - m = message_module.TestAllTypes() - self.assertSequenceEqual([], m.repeated_float) - m.repeated_float.extend([0.0]) - self.assertSequenceEqual([0.0], m.repeated_float) - m.repeated_float.extend([1.0, 2.0]) - self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float) - m.repeated_float.extend([3.0, 4.0]) - self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float) - - def testExtendStringWithPythonList(self, message_module): - """Test extending repeated string fields with python lists.""" - m = message_module.TestAllTypes() - self.assertSequenceEqual([], m.repeated_string) - m.repeated_string.extend(['']) - self.assertSequenceEqual([''], m.repeated_string) - m.repeated_string.extend(['11', '22']) - self.assertSequenceEqual(['', '11', '22'], m.repeated_string) - m.repeated_string.extend(['33', '44']) - self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string) - - def testExtendStringWithString(self, message_module): - """Test extending repeated string fields with characters from a string.""" - m = message_module.TestAllTypes() - self.assertSequenceEqual([], m.repeated_string) - m.repeated_string.extend('abc') - self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string) - - class TestIterable(object): - """This iterable object mimics the behavior of numpy.array. - - __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1. - - """ - - def __init__(self, values=None): - self._list = values or [] - - def __nonzero__(self): - size = len(self._list) - if size == 0: - return False - if size == 1: - return bool(self._list[0]) - raise ValueError('Truth value is ambiguous.') - - def __len__(self): - return len(self._list) - - def __iter__(self): - return self._list.__iter__() - - def testExtendInt32WithIterable(self, message_module): - """Test extending repeated int32 fields with iterable.""" - m = message_module.TestAllTypes() - self.assertSequenceEqual([], m.repeated_int32) - m.repeated_int32.extend(MessageTest.TestIterable([])) - self.assertSequenceEqual([], m.repeated_int32) - m.repeated_int32.extend(MessageTest.TestIterable([0])) - self.assertSequenceEqual([0], m.repeated_int32) - m.repeated_int32.extend(MessageTest.TestIterable([1, 2])) - self.assertSequenceEqual([0, 1, 2], m.repeated_int32) - m.repeated_int32.extend(MessageTest.TestIterable([3, 4])) - self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32) - - def testExtendFloatWithIterable(self, message_module): - """Test extending repeated float fields with iterable.""" - m = message_module.TestAllTypes() - self.assertSequenceEqual([], m.repeated_float) - m.repeated_float.extend(MessageTest.TestIterable([])) - self.assertSequenceEqual([], m.repeated_float) - m.repeated_float.extend(MessageTest.TestIterable([0.0])) - self.assertSequenceEqual([0.0], m.repeated_float) - m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0])) - self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float) - m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0])) - self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float) - - def testExtendStringWithIterable(self, message_module): - """Test extending repeated string fields with iterable.""" - m = message_module.TestAllTypes() - self.assertSequenceEqual([], m.repeated_string) - m.repeated_string.extend(MessageTest.TestIterable([])) - self.assertSequenceEqual([], m.repeated_string) - m.repeated_string.extend(MessageTest.TestIterable([''])) - self.assertSequenceEqual([''], m.repeated_string) - m.repeated_string.extend(MessageTest.TestIterable(['1', '2'])) - self.assertSequenceEqual(['', '1', '2'], m.repeated_string) - m.repeated_string.extend(MessageTest.TestIterable(['3', '4'])) - self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string) - - def testPickleRepeatedScalarContainer(self, message_module): - # TODO(tibell): The pure-Python implementation support pickling of - # scalar containers in *some* cases. For now the cpp2 version - # throws an exception to avoid a segfault. Investigate if we - # want to support pickling of these fields. - # - # For more information see: https://b2.corp.google.com/u/0/issues/18677897 - if (api_implementation.Type() != 'cpp' or - api_implementation.Version() == 2): - return - m = message_module.TestAllTypes() - with self.assertRaises(pickle.PickleError) as _: - pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL) - - def testSortEmptyRepeatedCompositeContainer(self, message_module): - """Exercise a scenario that has led to segfaults in the past. - """ - m = message_module.TestAllTypes() - m.repeated_nested_message.sort() - - def testHasFieldOnRepeatedField(self, message_module): - """Using HasField on a repeated field should raise an exception. - """ - m = message_module.TestAllTypes() - with self.assertRaises(ValueError) as _: - m.HasField('repeated_int32') - - def testRepeatedScalarFieldPop(self, message_module): - m = message_module.TestAllTypes() - with self.assertRaises(IndexError) as _: - m.repeated_int32.pop() - m.repeated_int32.extend(range(5)) - self.assertEqual(4, m.repeated_int32.pop()) - self.assertEqual(0, m.repeated_int32.pop(0)) - self.assertEqual(2, m.repeated_int32.pop(1)) - self.assertEqual([1, 3], m.repeated_int32) - - def testRepeatedCompositeFieldPop(self, message_module): - m = message_module.TestAllTypes() - with self.assertRaises(IndexError) as _: - m.repeated_nested_message.pop() - for i in range(5): - n = m.repeated_nested_message.add() - n.bb = i - self.assertEqual(4, m.repeated_nested_message.pop().bb) - self.assertEqual(0, m.repeated_nested_message.pop(0).bb) - self.assertEqual(2, m.repeated_nested_message.pop(1).bb) - self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message]) - - -# Class to test proto2-only features (required, extensions, etc.) -class Proto2Test(unittest.TestCase): - - def testFieldPresence(self): - message = unittest_pb2.TestAllTypes() - - self.assertFalse(message.HasField("optional_int32")) - self.assertFalse(message.HasField("optional_bool")) - self.assertFalse(message.HasField("optional_nested_message")) - - with self.assertRaises(ValueError): - message.HasField("field_doesnt_exist") - - with self.assertRaises(ValueError): - message.HasField("repeated_int32") - with self.assertRaises(ValueError): - message.HasField("repeated_nested_message") - - self.assertEqual(0, message.optional_int32) - self.assertEqual(False, message.optional_bool) - self.assertEqual(0, message.optional_nested_message.bb) - - # Fields are set even when setting the values to default values. - message.optional_int32 = 0 - message.optional_bool = False - message.optional_nested_message.bb = 0 - self.assertTrue(message.HasField("optional_int32")) - self.assertTrue(message.HasField("optional_bool")) - self.assertTrue(message.HasField("optional_nested_message")) - - # Set the fields to non-default values. - message.optional_int32 = 5 - message.optional_bool = True - message.optional_nested_message.bb = 15 - - self.assertTrue(message.HasField("optional_int32")) - self.assertTrue(message.HasField("optional_bool")) - self.assertTrue(message.HasField("optional_nested_message")) - - # Clearing the fields unsets them and resets their value to default. - message.ClearField("optional_int32") - message.ClearField("optional_bool") - message.ClearField("optional_nested_message") - - self.assertFalse(message.HasField("optional_int32")) - self.assertFalse(message.HasField("optional_bool")) - self.assertFalse(message.HasField("optional_nested_message")) - self.assertEqual(0, message.optional_int32) - self.assertEqual(False, message.optional_bool) - self.assertEqual(0, message.optional_nested_message.bb) - - # TODO(tibell): The C++ implementations actually allows assignment - # of unknown enum values to *scalar* fields (but not repeated - # fields). Once checked enum fields becomes the default in the - # Python implementation, the C++ implementation should follow suit. - def testAssignInvalidEnum(self): - """It should not be possible to assign an invalid enum number to an - enum field.""" - m = unittest_pb2.TestAllTypes() - - with self.assertRaises(ValueError) as _: - m.optional_nested_enum = 1234567 - self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567) - - def testGoldenExtensions(self): - golden_data = test_util.GoldenFileData('golden_message') - golden_message = unittest_pb2.TestAllExtensions() - golden_message.ParseFromString(golden_data) - all_set = unittest_pb2.TestAllExtensions() - test_util.SetAllExtensions(all_set) - self.assertEqual(all_set, golden_message) - self.assertEqual(golden_data, golden_message.SerializeToString()) - golden_copy = copy.deepcopy(golden_message) - self.assertEqual(golden_data, golden_copy.SerializeToString()) - - def testGoldenPackedExtensions(self): - golden_data = test_util.GoldenFileData('golden_packed_fields_message') - golden_message = unittest_pb2.TestPackedExtensions() - golden_message.ParseFromString(golden_data) - all_set = unittest_pb2.TestPackedExtensions() - test_util.SetAllPackedExtensions(all_set) - self.assertEqual(all_set, golden_message) - self.assertEqual(golden_data, all_set.SerializeToString()) - golden_copy = copy.deepcopy(golden_message) - self.assertEqual(golden_data, golden_copy.SerializeToString()) - - def testPickleIncompleteProto(self): - golden_message = unittest_pb2.TestRequired(a=1) - pickled_message = pickle.dumps(golden_message) - - unpickled_message = pickle.loads(pickled_message) - self.assertEqual(unpickled_message, golden_message) - self.assertEqual(unpickled_message.a, 1) - # This is still an incomplete proto - so serializing should fail - self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) - - - # TODO(haberman): this isn't really a proto2-specific test except that this - # message has a required field in it. Should probably be factored out so - # that we can test the other parts with proto3. - def testParsingMerge(self): - """Check the merge behavior when a required or optional field appears - multiple times in the input.""" - messages = [ - unittest_pb2.TestAllTypes(), - unittest_pb2.TestAllTypes(), - unittest_pb2.TestAllTypes() ] - messages[0].optional_int32 = 1 - messages[1].optional_int64 = 2 - messages[2].optional_int32 = 3 - messages[2].optional_string = 'hello' - - merged_message = unittest_pb2.TestAllTypes() - merged_message.optional_int32 = 3 - merged_message.optional_int64 = 2 - merged_message.optional_string = 'hello' - - generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator() - generator.field1.extend(messages) - generator.field2.extend(messages) - generator.field3.extend(messages) - generator.ext1.extend(messages) - generator.ext2.extend(messages) - generator.group1.add().field1.MergeFrom(messages[0]) - generator.group1.add().field1.MergeFrom(messages[1]) - generator.group1.add().field1.MergeFrom(messages[2]) - generator.group2.add().field1.MergeFrom(messages[0]) - generator.group2.add().field1.MergeFrom(messages[1]) - generator.group2.add().field1.MergeFrom(messages[2]) - - data = generator.SerializeToString() - parsing_merge = unittest_pb2.TestParsingMerge() - parsing_merge.ParseFromString(data) - - # Required and optional fields should be merged. - self.assertEqual(parsing_merge.required_all_types, merged_message) - self.assertEqual(parsing_merge.optional_all_types, merged_message) - self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types, - merged_message) - self.assertEqual(parsing_merge.Extensions[ - unittest_pb2.TestParsingMerge.optional_ext], - merged_message) - - # Repeated fields should not be merged. - self.assertEqual(len(parsing_merge.repeated_all_types), 3) - self.assertEqual(len(parsing_merge.repeatedgroup), 3) - self.assertEqual(len(parsing_merge.Extensions[ - unittest_pb2.TestParsingMerge.repeated_ext]), 3) - - def testPythonicInit(self): - message = unittest_pb2.TestAllTypes( - optional_int32=100, - optional_fixed32=200, - optional_float=300.5, - optional_bytes=b'x', - optionalgroup={'a': 400}, - optional_nested_message={'bb': 500}, - optional_nested_enum='BAZ', - repeatedgroup=[{'a': 600}, - {'a': 700}], - repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR], - default_int32=800, - oneof_string='y') - self.assertIsInstance(message, unittest_pb2.TestAllTypes) - self.assertEqual(100, message.optional_int32) - self.assertEqual(200, message.optional_fixed32) - self.assertEqual(300.5, message.optional_float) - self.assertEqual(b'x', message.optional_bytes) - self.assertEqual(400, message.optionalgroup.a) - self.assertIsInstance(message.optional_nested_message, unittest_pb2.TestAllTypes.NestedMessage) - self.assertEqual(500, message.optional_nested_message.bb) - self.assertEqual(unittest_pb2.TestAllTypes.BAZ, - message.optional_nested_enum) - self.assertEqual(2, len(message.repeatedgroup)) - self.assertEqual(600, message.repeatedgroup[0].a) - self.assertEqual(700, message.repeatedgroup[1].a) - self.assertEqual(2, len(message.repeated_nested_enum)) - self.assertEqual(unittest_pb2.TestAllTypes.FOO, - message.repeated_nested_enum[0]) - self.assertEqual(unittest_pb2.TestAllTypes.BAR, - message.repeated_nested_enum[1]) - self.assertEqual(800, message.default_int32) - self.assertEqual('y', message.oneof_string) - self.assertFalse(message.HasField('optional_int64')) - self.assertEqual(0, len(message.repeated_float)) - self.assertEqual(42, message.default_int64) - - message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ') - self.assertEqual(unittest_pb2.TestAllTypes.BAZ, - message.optional_nested_enum) - - with self.assertRaises(ValueError): - unittest_pb2.TestAllTypes( - optional_nested_message={'INVALID_NESTED_FIELD': 17}) - - with self.assertRaises(TypeError): - unittest_pb2.TestAllTypes( - optional_nested_message={'bb': 'INVALID_VALUE_TYPE'}) - - with self.assertRaises(ValueError): - unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL') - - with self.assertRaises(ValueError): - unittest_pb2.TestAllTypes(repeated_nested_enum='FOO') - - - -# Class to test proto3-only features/behavior (updated field presence & enums) -class Proto3Test(unittest.TestCase): - - # Utility method for comparing equality with a map. - def assertMapIterEquals(self, map_iter, dict_value): - # Avoid mutating caller's copy. - dict_value = dict(dict_value) - - for k, v in map_iter: - self.assertEqual(v, dict_value[k]) - del dict_value[k] - - self.assertEqual({}, dict_value) - - def testFieldPresence(self): - message = unittest_proto3_arena_pb2.TestAllTypes() - - # We can't test presence of non-repeated, non-submessage fields. - with self.assertRaises(ValueError): - message.HasField('optional_int32') - with self.assertRaises(ValueError): - message.HasField('optional_float') - with self.assertRaises(ValueError): - message.HasField('optional_string') - with self.assertRaises(ValueError): - message.HasField('optional_bool') - - # But we can still test presence of submessage fields. - self.assertFalse(message.HasField('optional_nested_message')) - - # As with proto2, we can't test presence of fields that don't exist, or - # repeated fields. - with self.assertRaises(ValueError): - message.HasField('field_doesnt_exist') - - with self.assertRaises(ValueError): - message.HasField('repeated_int32') - with self.assertRaises(ValueError): - message.HasField('repeated_nested_message') - - # Fields should default to their type-specific default. - self.assertEqual(0, message.optional_int32) - self.assertEqual(0, message.optional_float) - self.assertEqual('', message.optional_string) - self.assertEqual(False, message.optional_bool) - self.assertEqual(0, message.optional_nested_message.bb) - - # Setting a submessage should still return proper presence information. - message.optional_nested_message.bb = 0 - self.assertTrue(message.HasField('optional_nested_message')) - - # Set the fields to non-default values. - message.optional_int32 = 5 - message.optional_float = 1.1 - message.optional_string = 'abc' - message.optional_bool = True - message.optional_nested_message.bb = 15 - - # Clearing the fields unsets them and resets their value to default. - message.ClearField('optional_int32') - message.ClearField('optional_float') - message.ClearField('optional_string') - message.ClearField('optional_bool') - message.ClearField('optional_nested_message') - - self.assertEqual(0, message.optional_int32) - self.assertEqual(0, message.optional_float) - self.assertEqual('', message.optional_string) - self.assertEqual(False, message.optional_bool) - self.assertEqual(0, message.optional_nested_message.bb) - - def testAssignUnknownEnum(self): - """Assigning an unknown enum value is allowed and preserves the value.""" - m = unittest_proto3_arena_pb2.TestAllTypes() - - m.optional_nested_enum = 1234567 - self.assertEqual(1234567, m.optional_nested_enum) - m.repeated_nested_enum.append(22334455) - self.assertEqual(22334455, m.repeated_nested_enum[0]) - # Assignment is a different code path than append for the C++ impl. - m.repeated_nested_enum[0] = 7654321 - self.assertEqual(7654321, m.repeated_nested_enum[0]) - serialized = m.SerializeToString() - - m2 = unittest_proto3_arena_pb2.TestAllTypes() - m2.ParseFromString(serialized) - self.assertEqual(1234567, m2.optional_nested_enum) - self.assertEqual(7654321, m2.repeated_nested_enum[0]) - - # Map isn't really a proto3-only feature. But there is no proto2 equivalent - # of google/protobuf/map_unittest.proto right now, so it's not easy to - # test both with the same test like we do for the other proto2/proto3 tests. - # (google/protobuf/map_protobuf_unittest.proto is very different in the set - # of messages and fields it contains). - def testScalarMapDefaults(self): - msg = map_unittest_pb2.TestMap() - - # Scalars start out unset. - self.assertFalse(-123 in msg.map_int32_int32) - self.assertFalse(-2**33 in msg.map_int64_int64) - self.assertFalse(123 in msg.map_uint32_uint32) - self.assertFalse(2**33 in msg.map_uint64_uint64) - self.assertFalse(123 in msg.map_int32_double) - self.assertFalse(False in msg.map_bool_bool) - self.assertFalse('abc' in msg.map_string_string) - self.assertFalse(111 in msg.map_int32_bytes) - self.assertFalse(888 in msg.map_int32_enum) - - # Accessing an unset key returns the default. - self.assertEqual(0, msg.map_int32_int32[-123]) - self.assertEqual(0, msg.map_int64_int64[-2**33]) - self.assertEqual(0, msg.map_uint32_uint32[123]) - self.assertEqual(0, msg.map_uint64_uint64[2**33]) - self.assertEqual(0.0, msg.map_int32_double[123]) - self.assertTrue(isinstance(msg.map_int32_double[123], float)) - self.assertEqual(False, msg.map_bool_bool[False]) - self.assertTrue(isinstance(msg.map_bool_bool[False], bool)) - self.assertEqual('', msg.map_string_string['abc']) - self.assertEqual(b'', msg.map_int32_bytes[111]) - self.assertEqual(0, msg.map_int32_enum[888]) - - # It also sets the value in the map - self.assertTrue(-123 in msg.map_int32_int32) - self.assertTrue(-2**33 in msg.map_int64_int64) - self.assertTrue(123 in msg.map_uint32_uint32) - self.assertTrue(2**33 in msg.map_uint64_uint64) - self.assertTrue(123 in msg.map_int32_double) - self.assertTrue(False in msg.map_bool_bool) - self.assertTrue('abc' in msg.map_string_string) - self.assertTrue(111 in msg.map_int32_bytes) - self.assertTrue(888 in msg.map_int32_enum) - - self.assertIsInstance(msg.map_string_string['abc'], six.text_type) - - # Accessing an unset key still throws TypeError if the type of the key - # is incorrect. - with self.assertRaises(TypeError): - msg.map_string_string[123] - - with self.assertRaises(TypeError): - 123 in msg.map_string_string - - def testMapGet(self): - # Need to test that get() properly returns the default, even though the dict - # has defaultdict-like semantics. - msg = map_unittest_pb2.TestMap() - - self.assertIsNone(msg.map_int32_int32.get(5)) - self.assertEqual(10, msg.map_int32_int32.get(5, 10)) - self.assertIsNone(msg.map_int32_int32.get(5)) - - msg.map_int32_int32[5] = 15 - self.assertEqual(15, msg.map_int32_int32.get(5)) - - self.assertIsNone(msg.map_int32_foreign_message.get(5)) - self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10)) - - submsg = msg.map_int32_foreign_message[5] - self.assertIs(submsg, msg.map_int32_foreign_message.get(5)) - - def testScalarMap(self): - msg = map_unittest_pb2.TestMap() - - self.assertEqual(0, len(msg.map_int32_int32)) - self.assertFalse(5 in msg.map_int32_int32) - - msg.map_int32_int32[-123] = -456 - msg.map_int64_int64[-2**33] = -2**34 - msg.map_uint32_uint32[123] = 456 - msg.map_uint64_uint64[2**33] = 2**34 - msg.map_string_string['abc'] = '123' - msg.map_int32_enum[888] = 2 - - self.assertEqual([], msg.FindInitializationErrors()) - - self.assertEqual(1, len(msg.map_string_string)) - - # Bad key. - with self.assertRaises(TypeError): - msg.map_string_string[123] = '123' - - # Verify that trying to assign a bad key doesn't actually add a member to - # the map. - self.assertEqual(1, len(msg.map_string_string)) - - # Bad value. - with self.assertRaises(TypeError): - msg.map_string_string['123'] = 123 - - serialized = msg.SerializeToString() - msg2 = map_unittest_pb2.TestMap() - msg2.ParseFromString(serialized) - - # Bad key. - with self.assertRaises(TypeError): - msg2.map_string_string[123] = '123' - - # Bad value. - with self.assertRaises(TypeError): - msg2.map_string_string['123'] = 123 - - self.assertEqual(-456, msg2.map_int32_int32[-123]) - self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) - self.assertEqual(456, msg2.map_uint32_uint32[123]) - self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) - self.assertEqual('123', msg2.map_string_string['abc']) - self.assertEqual(2, msg2.map_int32_enum[888]) - - def testStringUnicodeConversionInMap(self): - msg = map_unittest_pb2.TestMap() - - unicode_obj = u'\u1234' - bytes_obj = unicode_obj.encode('utf8') - - msg.map_string_string[bytes_obj] = bytes_obj - - (key, value) = list(msg.map_string_string.items())[0] - - self.assertEqual(key, unicode_obj) - self.assertEqual(value, unicode_obj) - - self.assertIsInstance(key, six.text_type) - self.assertIsInstance(value, six.text_type) - - def testMessageMap(self): - msg = map_unittest_pb2.TestMap() - - self.assertEqual(0, len(msg.map_int32_foreign_message)) - self.assertFalse(5 in msg.map_int32_foreign_message) - - msg.map_int32_foreign_message[123] - # get_or_create() is an alias for getitem. - msg.map_int32_foreign_message.get_or_create(-456) - - self.assertEqual(2, len(msg.map_int32_foreign_message)) - self.assertIn(123, msg.map_int32_foreign_message) - self.assertIn(-456, msg.map_int32_foreign_message) - self.assertEqual(2, len(msg.map_int32_foreign_message)) - - # Bad key. - with self.assertRaises(TypeError): - msg.map_int32_foreign_message['123'] - - # Can't assign directly to submessage. - with self.assertRaises(ValueError): - msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123] - - # Verify that trying to assign a bad key doesn't actually add a member to - # the map. - self.assertEqual(2, len(msg.map_int32_foreign_message)) - - serialized = msg.SerializeToString() - msg2 = map_unittest_pb2.TestMap() - msg2.ParseFromString(serialized) - - self.assertEqual(2, len(msg2.map_int32_foreign_message)) - self.assertIn(123, msg2.map_int32_foreign_message) - self.assertIn(-456, msg2.map_int32_foreign_message) - self.assertEqual(2, len(msg2.map_int32_foreign_message)) - - def testMergeFrom(self): - msg = map_unittest_pb2.TestMap() - msg.map_int32_int32[12] = 34 - msg.map_int32_int32[56] = 78 - msg.map_int64_int64[22] = 33 - msg.map_int32_foreign_message[111].c = 5 - msg.map_int32_foreign_message[222].c = 10 - - msg2 = map_unittest_pb2.TestMap() - msg2.map_int32_int32[12] = 55 - msg2.map_int64_int64[88] = 99 - msg2.map_int32_foreign_message[222].c = 15 - msg2.map_int32_foreign_message[222].d = 20 - old_map_value = msg2.map_int32_foreign_message[222] - - msg2.MergeFrom(msg) - - self.assertEqual(34, msg2.map_int32_int32[12]) - self.assertEqual(78, msg2.map_int32_int32[56]) - self.assertEqual(33, msg2.map_int64_int64[22]) - self.assertEqual(99, msg2.map_int64_int64[88]) - self.assertEqual(5, msg2.map_int32_foreign_message[111].c) - self.assertEqual(10, msg2.map_int32_foreign_message[222].c) - self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d')) - self.assertEqual(15, old_map_value.c) - - # Verify that there is only one entry per key, even though the MergeFrom - # may have internally created multiple entries for a single key in the - # list representation. - as_dict = {} - for key in msg2.map_int32_foreign_message: - self.assertFalse(key in as_dict) - as_dict[key] = msg2.map_int32_foreign_message[key].c - - self.assertEqual({111: 5, 222: 10}, as_dict) - - # Special case: test that delete of item really removes the item, even if - # there might have physically been duplicate keys due to the previous merge. - # This is only a special case for the C++ implementation which stores the - # map as an array. - del msg2.map_int32_int32[12] - self.assertFalse(12 in msg2.map_int32_int32) - - del msg2.map_int32_foreign_message[222] - self.assertFalse(222 in msg2.map_int32_foreign_message) - - def testMergeFromBadType(self): - msg = map_unittest_pb2.TestMap() - with self.assertRaisesRegexp( - TypeError, - r'Parameter to MergeFrom\(\) must be instance of same class: expected ' - r'.*TestMap got int\.'): - msg.MergeFrom(1) - - def testCopyFromBadType(self): - msg = map_unittest_pb2.TestMap() - with self.assertRaisesRegexp( - TypeError, - r'Parameter to [A-Za-z]*From\(\) must be instance of same class: ' - r'expected .*TestMap got int\.'): - msg.CopyFrom(1) - - def testIntegerMapWithLongs(self): - msg = map_unittest_pb2.TestMap() - msg.map_int32_int32[long(-123)] = long(-456) - msg.map_int64_int64[long(-2**33)] = long(-2**34) - msg.map_uint32_uint32[long(123)] = long(456) - msg.map_uint64_uint64[long(2**33)] = long(2**34) - - serialized = msg.SerializeToString() - msg2 = map_unittest_pb2.TestMap() - msg2.ParseFromString(serialized) - - self.assertEqual(-456, msg2.map_int32_int32[-123]) - self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) - self.assertEqual(456, msg2.map_uint32_uint32[123]) - self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) - - def testMapAssignmentCausesPresence(self): - msg = map_unittest_pb2.TestMapSubmessage() - msg.test_map.map_int32_int32[123] = 456 - - serialized = msg.SerializeToString() - msg2 = map_unittest_pb2.TestMapSubmessage() - msg2.ParseFromString(serialized) - - self.assertEqual(msg, msg2) - - # Now test that various mutations of the map properly invalidate the - # cached size of the submessage. - msg.test_map.map_int32_int32[888] = 999 - serialized = msg.SerializeToString() - msg2.ParseFromString(serialized) - self.assertEqual(msg, msg2) - - msg.test_map.map_int32_int32.clear() - serialized = msg.SerializeToString() - msg2.ParseFromString(serialized) - self.assertEqual(msg, msg2) - - def testMapAssignmentCausesPresenceForSubmessages(self): - msg = map_unittest_pb2.TestMapSubmessage() - msg.test_map.map_int32_foreign_message[123].c = 5 - - serialized = msg.SerializeToString() - msg2 = map_unittest_pb2.TestMapSubmessage() - msg2.ParseFromString(serialized) - - self.assertEqual(msg, msg2) - - # Now test that various mutations of the map properly invalidate the - # cached size of the submessage. - msg.test_map.map_int32_foreign_message[888].c = 7 - serialized = msg.SerializeToString() - msg2.ParseFromString(serialized) - self.assertEqual(msg, msg2) - - msg.test_map.map_int32_foreign_message[888].MergeFrom( - msg.test_map.map_int32_foreign_message[123]) - serialized = msg.SerializeToString() - msg2.ParseFromString(serialized) - self.assertEqual(msg, msg2) - - msg.test_map.map_int32_foreign_message.clear() - serialized = msg.SerializeToString() - msg2.ParseFromString(serialized) - self.assertEqual(msg, msg2) - - def testModifyMapWhileIterating(self): - msg = map_unittest_pb2.TestMap() - - string_string_iter = iter(msg.map_string_string) - int32_foreign_iter = iter(msg.map_int32_foreign_message) - - msg.map_string_string['abc'] = '123' - msg.map_int32_foreign_message[5].c = 5 - - with self.assertRaises(RuntimeError): - for key in string_string_iter: - pass - - with self.assertRaises(RuntimeError): - for key in int32_foreign_iter: - pass - - def testSubmessageMap(self): - msg = map_unittest_pb2.TestMap() - - submsg = msg.map_int32_foreign_message[111] - self.assertIs(submsg, msg.map_int32_foreign_message[111]) - self.assertIsInstance(submsg, unittest_pb2.ForeignMessage) - - submsg.c = 5 - - serialized = msg.SerializeToString() - msg2 = map_unittest_pb2.TestMap() - msg2.ParseFromString(serialized) - - self.assertEqual(5, msg2.map_int32_foreign_message[111].c) - - # Doesn't allow direct submessage assignment. - with self.assertRaises(ValueError): - msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage() - - def testMapIteration(self): - msg = map_unittest_pb2.TestMap() - - for k, v in msg.map_int32_int32.items(): - # Should not be reached. - self.assertTrue(False) - - msg.map_int32_int32[2] = 4 - msg.map_int32_int32[3] = 6 - msg.map_int32_int32[4] = 8 - self.assertEqual(3, len(msg.map_int32_int32)) - - matching_dict = {2: 4, 3: 6, 4: 8} - self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict) - - def testMapItems(self): - # Map items used to have strange behaviors when use c extension. Because - # [] may reorder the map and invalidate any exsting iterators. - # TODO(jieluo): Check if [] reordering the map is a bug or intended - # behavior. - msg = map_unittest_pb2.TestMap() - msg.map_string_string['local_init_op'] = '' - msg.map_string_string['trainable_variables'] = '' - msg.map_string_string['variables'] = '' - msg.map_string_string['init_op'] = '' - msg.map_string_string['summaries'] = '' - items1 = msg.map_string_string.items() - items2 = msg.map_string_string.items() - self.assertEqual(items1, items2) - - def testMapIterationClearMessage(self): - # Iterator needs to work even if message and map are deleted. - msg = map_unittest_pb2.TestMap() - - msg.map_int32_int32[2] = 4 - msg.map_int32_int32[3] = 6 - msg.map_int32_int32[4] = 8 - - it = msg.map_int32_int32.items() - del msg - - matching_dict = {2: 4, 3: 6, 4: 8} - self.assertMapIterEquals(it, matching_dict) - - def testMapConstruction(self): - msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4}) - self.assertEqual(2, msg.map_int32_int32[1]) - self.assertEqual(4, msg.map_int32_int32[3]) - - msg = map_unittest_pb2.TestMap( - map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)}) - self.assertEqual(5, msg.map_int32_foreign_message[3].c) - - def testMapValidAfterFieldCleared(self): - # Map needs to work even if field is cleared. - # For the C++ implementation this tests the correctness of - # ScalarMapContainer::Release() - msg = map_unittest_pb2.TestMap() - int32_map = msg.map_int32_int32 - - int32_map[2] = 4 - int32_map[3] = 6 - int32_map[4] = 8 - - msg.ClearField('map_int32_int32') - self.assertEqual(b'', msg.SerializeToString()) - matching_dict = {2: 4, 3: 6, 4: 8} - self.assertMapIterEquals(int32_map.items(), matching_dict) - - def testMessageMapValidAfterFieldCleared(self): - # Map needs to work even if field is cleared. - # For the C++ implementation this tests the correctness of - # ScalarMapContainer::Release() - msg = map_unittest_pb2.TestMap() - int32_foreign_message = msg.map_int32_foreign_message - - int32_foreign_message[2].c = 5 - - msg.ClearField('map_int32_foreign_message') - self.assertEqual(b'', msg.SerializeToString()) - self.assertTrue(2 in int32_foreign_message.keys()) - - def testMapIterInvalidatedByClearField(self): - # Map iterator is invalidated when field is cleared. - # But this case does need to not crash the interpreter. - # For the C++ implementation this tests the correctness of - # ScalarMapContainer::Release() - msg = map_unittest_pb2.TestMap() - - it = iter(msg.map_int32_int32) - - msg.ClearField('map_int32_int32') - with self.assertRaises(RuntimeError): - for _ in it: - pass - - it = iter(msg.map_int32_foreign_message) - msg.ClearField('map_int32_foreign_message') - with self.assertRaises(RuntimeError): - for _ in it: - pass - - def testMapDelete(self): - msg = map_unittest_pb2.TestMap() - - self.assertEqual(0, len(msg.map_int32_int32)) - - msg.map_int32_int32[4] = 6 - self.assertEqual(1, len(msg.map_int32_int32)) - - with self.assertRaises(KeyError): - del msg.map_int32_int32[88] - - del msg.map_int32_int32[4] - self.assertEqual(0, len(msg.map_int32_int32)) - - def testMapsAreMapping(self): - msg = map_unittest_pb2.TestMap() - self.assertIsInstance(msg.map_int32_int32, collections.Mapping) - self.assertIsInstance(msg.map_int32_int32, collections.MutableMapping) - self.assertIsInstance(msg.map_int32_foreign_message, collections.Mapping) - self.assertIsInstance(msg.map_int32_foreign_message, - collections.MutableMapping) - - def testMapFindInitializationErrorsSmokeTest(self): - msg = map_unittest_pb2.TestMap() - msg.map_string_string['abc'] = '123' - msg.map_int32_int32[35] = 64 - msg.map_string_foreign_message['foo'].c = 5 - self.assertEqual(0, len(msg.FindInitializationErrors())) - - - -class ValidTypeNamesTest(unittest.TestCase): - - def assertImportFromName(self, msg, base_name): - # Parse <type 'module.class_name'> to extra 'some.name' as a string. - tp_name = str(type(msg)).split("'")[1] - valid_names = ('Repeated%sContainer' % base_name, - 'Repeated%sFieldContainer' % base_name) - self.assertTrue(any(tp_name.endswith(v) for v in valid_names), - '%r does end with any of %r' % (tp_name, valid_names)) - - parts = tp_name.split('.') - class_name = parts[-1] - module_name = '.'.join(parts[:-1]) - __import__(module_name, fromlist=[class_name]) - - def testTypeNamesCanBeImported(self): - # If import doesn't work, pickling won't work either. - pb = unittest_pb2.TestAllTypes() - self.assertImportFromName(pb.repeated_int32, 'Scalar') - self.assertImportFromName(pb.repeated_nested_message, 'Composite') - -class PackedFieldTest(unittest.TestCase): - - def setMessage(self, message): - message.repeated_int32.append(1) - message.repeated_int64.append(1) - message.repeated_uint32.append(1) - message.repeated_uint64.append(1) - message.repeated_sint32.append(1) - message.repeated_sint64.append(1) - message.repeated_fixed32.append(1) - message.repeated_fixed64.append(1) - message.repeated_sfixed32.append(1) - message.repeated_sfixed64.append(1) - message.repeated_float.append(1.0) - message.repeated_double.append(1.0) - message.repeated_bool.append(True) - message.repeated_nested_enum.append(1) - - def testPackedFields(self): - message = packed_field_test_pb2.TestPackedTypes() - self.setMessage(message) - golden_data = (b'\x0A\x01\x01' - b'\x12\x01\x01' - b'\x1A\x01\x01' - b'\x22\x01\x01' - b'\x2A\x01\x02' - b'\x32\x01\x02' - b'\x3A\x04\x01\x00\x00\x00' - b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00' - b'\x4A\x04\x01\x00\x00\x00' - b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00' - b'\x5A\x04\x00\x00\x80\x3f' - b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f' - b'\x6A\x01\x01' - b'\x72\x01\x01') - self.assertEqual(golden_data, message.SerializeToString()) - - def testUnpackedFields(self): - message = packed_field_test_pb2.TestUnpackedTypes() - self.setMessage(message) - golden_data = (b'\x08\x01' - b'\x10\x01' - b'\x18\x01' - b'\x20\x01' - b'\x28\x02' - b'\x30\x02' - b'\x3D\x01\x00\x00\x00' - b'\x41\x01\x00\x00\x00\x00\x00\x00\x00' - b'\x4D\x01\x00\x00\x00' - b'\x51\x01\x00\x00\x00\x00\x00\x00\x00' - b'\x5D\x00\x00\x80\x3f' - b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f' - b'\x68\x01' - b'\x70\x01') - self.assertEqual(golden_data, message.SerializeToString()) - - -@unittest.skipIf(api_implementation.Type() != 'cpp', - 'explicit tests of the C++ implementation') -class OversizeProtosTest(unittest.TestCase): - - def setUp(self): - self.file_desc = """ - name: "f/f.msg2" - package: "f" - message_type { - name: "msg1" - field { - name: "payload" - number: 1 - label: LABEL_OPTIONAL - type: TYPE_STRING - } - } - message_type { - name: "msg2" - field { - name: "field" - number: 1 - label: LABEL_OPTIONAL - type: TYPE_MESSAGE - type_name: "msg1" - } - } - """ - pool = descriptor_pool.DescriptorPool() - desc = descriptor_pb2.FileDescriptorProto() - text_format.Parse(self.file_desc, desc) - pool.Add(desc) - self.proto_cls = message_factory.MessageFactory(pool).GetPrototype( - pool.FindMessageTypeByName('f.msg2')) - self.p = self.proto_cls() - self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1) - self.p_serialized = self.p.SerializeToString() - - def testAssertOversizeProto(self): - from google.protobuf.pyext._message import SetAllowOversizeProtos - SetAllowOversizeProtos(False) - q = self.proto_cls() - try: - q.ParseFromString(self.p_serialized) - except message.DecodeError as e: - self.assertEqual(str(e), 'Error parsing message') - - def testSucceedOversizeProto(self): - from google.protobuf.pyext._message import SetAllowOversizeProtos - SetAllowOversizeProtos(True) - q = self.proto_cls() - q.ParseFromString(self.p_serialized) - self.assertEqual(self.p.field.payload, q.field.payload) - -if __name__ == '__main__': - unittest.main() |