From 4d7232a5dfdedf38d80bb406f1814b40f638e2f2 Mon Sep 17 00:00:00 2001 From: Ilya Etingof Date: Sun, 29 Oct 2017 16:04:01 +0100 Subject: Pickle protocol fixes (#99) * do not blow up on pickle protocol attributes look up * added Pickle tests * More fixes to pickle protocol support * __slots__ lookup allowed at NoValue * SizedInteger moved from BitString scope to the univ module scope --- CHANGES.rst | 1 + pyasn1/type/base.py | 28 ++++- pyasn1/type/univ.py | 75 +++++++------- tests/type/test_char.py | 16 +++ tests/type/test_univ.py | 254 ++++++++++++++++++++++++++++++++++++++++++++++ tests/type/test_useful.py | 38 +++++++ 6 files changed, 372 insertions(+), 40 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 8d11616..0a71810 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -22,6 +22,7 @@ Revision 0.4.1, released XX-10-2017 opposed to constructed class). - Fixed CER/DER encoders to respect tagged CHOICE when ordering SET components +- Fixed ASN.1 types not to interfere with the Pickle protocol Revision 0.3.7, released 04-10-2017 ----------------------------------- diff --git a/pyasn1/type/base.py b/pyasn1/type/base.py index f22823b..73755cc 100644 --- a/pyasn1/type/base.py +++ b/pyasn1/type/base.py @@ -155,9 +155,31 @@ class NoValue(object): Any operation attempted on the *noValue* object will raise the *PyAsn1Error* exception. """ - skipMethods = ('__getattribute__', '__getattr__', '__setattr__', '__delattr__', - '__class__', '__init__', '__del__', '__new__', '__repr__', - '__qualname__', '__objclass__', 'im_class', '__sizeof__') + skipMethods = set( + ('__slots__', + # attributes + '__getattribute__', + '__getattr__', + '__setattr__', + '__delattr__', + # class instance + '__class__', + '__init__', + '__del__', + '__new__', + '__repr__', + '__qualname__', + '__objclass__', + 'im_class', + '__sizeof__', + # pickle protocol + '__reduce__', + '__reduce_ex__', + '__getnewargs__', + '__getinitargs__', + '__getstate__', + '__setstate__') + ) _instance = None diff --git a/pyasn1/type/univ.py b/pyasn1/type/univ.py index 43c0d2c..1ccbf8e 100644 --- a/pyasn1/type/univ.py +++ b/pyasn1/type/univ.py @@ -272,6 +272,26 @@ class Boolean(Integer): # Optimization for faster codec lookup typeId = Integer.getTypeId() +if sys.version_info[0] < 3: + SizedIntegerBase = long +else: + SizedIntegerBase = int + + +class SizedInteger(SizedIntegerBase): + bitLength = leadingZeroBits = None + + def setBitLength(self, bitLength): + self.bitLength = bitLength + self.leadingZeroBits = max(bitLength - integer.bitLength(self), 0) + return self + + def __len__(self): + if self.bitLength is None: + self.setBitLength(integer.bitLength(self)) + + return self.bitLength + class BitString(base.AbstractSimpleAsn1Item): """Create |ASN.1| schema or value object. @@ -328,25 +348,6 @@ class BitString(base.AbstractSimpleAsn1Item): defaultBinValue = defaultHexValue = noValue - if sys.version_info[0] < 3: - SizedIntegerBase = long - else: - SizedIntegerBase = int - - class SizedInteger(SizedIntegerBase): - bitLength = leadingZeroBits = None - - def setBitLength(self, bitLength): - self.bitLength = bitLength - self.leadingZeroBits = max(bitLength - integer.bitLength(self), 0) - return self - - def __len__(self): - if self.bitLength is None: - self.setBitLength(integer.bitLength(self)) - - return self.bitLength - def __init__(self, value=noValue, **kwargs): if value is noValue: if kwargs: @@ -428,11 +429,11 @@ class BitString(base.AbstractSimpleAsn1Item): def __add__(self, value): value = self.prettyIn(value) - return self.clone(self.SizedInteger(self._value << len(value) | value).setBitLength(len(self._value) + len(value))) + return self.clone(SizedInteger(self._value << len(value) | value).setBitLength(len(self._value) + len(value))) def __radd__(self, value): value = self.prettyIn(value) - return self.clone(self.SizedInteger(value << len(self._value) | self._value).setBitLength(len(self._value) + len(value))) + return self.clone(SizedInteger(value << len(self._value) | self._value).setBitLength(len(self._value) + len(value))) def __mul__(self, value): bitString = self._value @@ -446,10 +447,10 @@ class BitString(base.AbstractSimpleAsn1Item): return self * value def __lshift__(self, count): - return self.clone(self.SizedInteger(self._value << count).setBitLength(len(self._value) + count)) + return self.clone(SizedInteger(self._value << count).setBitLength(len(self._value) + count)) def __rshift__(self, count): - return self.clone(self.SizedInteger(self._value >> count).setBitLength(max(0, len(self._value) - count))) + return self.clone(SizedInteger(self._value >> count).setBitLength(max(0, len(self._value) - count))) def __int__(self): return self._value @@ -498,14 +499,14 @@ class BitString(base.AbstractSimpleAsn1Item): Text string like 'DEADBEEF' """ try: - value = cls.SizedInteger(value, 16).setBitLength(len(value) * 4) + value = SizedInteger(value, 16).setBitLength(len(value) * 4) except ValueError: raise error.PyAsn1Error('%s.fromHexString() error: %s' % (cls.__name__, sys.exc_info()[1])) if prepend is not None: - value = cls.SizedInteger( - (cls.SizedInteger(prepend) << len(value)) | value + value = SizedInteger( + (SizedInteger(prepend) << len(value)) | value ).setBitLength(len(prepend) + len(value)) if not internalFormat: @@ -523,14 +524,14 @@ class BitString(base.AbstractSimpleAsn1Item): Text string like '1010111' """ try: - value = cls.SizedInteger(value or '0', 2).setBitLength(len(value)) + value = SizedInteger(value or '0', 2).setBitLength(len(value)) except ValueError: raise error.PyAsn1Error('%s.fromBinaryString() error: %s' % (cls.__name__, sys.exc_info()[1])) if prepend is not None: - value = cls.SizedInteger( - (cls.SizedInteger(prepend) << len(value)) | value + value = SizedInteger( + (SizedInteger(prepend) << len(value)) | value ).setBitLength(len(prepend) + len(value)) if not internalFormat: @@ -547,11 +548,11 @@ class BitString(base.AbstractSimpleAsn1Item): value: :class:`str` (Py2) or :class:`bytes` (Py3) Text string like '\\\\x01\\\\xff' (Py2) or b'\\\\x01\\\\xff' (Py3) """ - value = cls.SizedInteger(integer.from_bytes(value) >> padding).setBitLength(len(value) * 8 - padding) + value = SizedInteger(integer.from_bytes(value) >> padding).setBitLength(len(value) * 8 - padding) if prepend is not None: - value = cls.SizedInteger( - (cls.SizedInteger(prepend) << len(value)) | value + value = SizedInteger( + (SizedInteger(prepend) << len(value)) | value ).setBitLength(len(prepend) + len(value)) if not internalFormat: @@ -560,11 +561,11 @@ class BitString(base.AbstractSimpleAsn1Item): return value def prettyIn(self, value): - if isinstance(value, self.SizedInteger): + if isinstance(value, SizedInteger): return value elif octets.isStringType(value): if not value: - return self.SizedInteger(0).setBitLength(0) + return SizedInteger(0).setBitLength(0) elif value[0] == '\'': # "'1011'B" -- ASN.1 schema representation (deprecated) if value[-2:] == '\'B': @@ -592,7 +593,7 @@ class BitString(base.AbstractSimpleAsn1Item): for bitPosition in bitPositions: number |= 1 << (rightmostPosition - bitPosition) - return self.SizedInteger(number).setBitLength(rightmostPosition + 1) + return SizedInteger(number).setBitLength(rightmostPosition + 1) elif value.startswith('0x'): return self.fromHexString(value[2:], internalFormat=True) @@ -607,10 +608,10 @@ class BitString(base.AbstractSimpleAsn1Item): return self.fromBinaryString(''.join([b and '1' or '0' for b in value]), internalFormat=True) elif isinstance(value, BitString): - return self.SizedInteger(value).setBitLength(len(value)) + return SizedInteger(value).setBitLength(len(value)) elif isinstance(value, intTypes): - return self.SizedInteger(value) + return SizedInteger(value) else: raise error.PyAsn1Error( diff --git a/tests/type/test_char.py b/tests/type/test_char.py index 74550c0..1301a50 100644 --- a/tests/type/test_char.py +++ b/tests/type/test_char.py @@ -5,6 +5,7 @@ # License: http://pyasn1.sf.net/license.html # import sys +import pickle try: import unittest2 as unittest @@ -111,6 +112,21 @@ class AbstractStringTestCase(object): def testReverse(self): assert list(reversed(self.asn1String)) == list(reversed(self.pythonString)) + def testSchemaPickling(self): + old_asn1 = self.asn1Type() + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert type(new_asn1) == self.asn1Type + assert old_asn1.isSameTypeWith(new_asn1) + + def testValuePickling(self): + old_asn1 = self.asn1String + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert new_asn1 == self.asn1String + class VisibleStringTestCase(AbstractStringTestCase, BaseTestCase): diff --git a/tests/type/test_univ.py b/tests/type/test_univ.py index 166af44..36a3fd4 100644 --- a/tests/type/test_univ.py +++ b/tests/type/test_univ.py @@ -6,6 +6,7 @@ # import sys import math +import pickle try: import unittest2 as unittest @@ -296,6 +297,24 @@ class IntegerTestCase(BaseTestCase): ) +class IntegerPicklingTestCase(unittest.TestCase): + + def testSchemaPickling(self): + old_asn1 = univ.Integer() + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert type(new_asn1) == univ.Integer + assert old_asn1.isSameTypeWith(new_asn1) + + def testValuePickling(self): + old_asn1 = univ.Integer(-123) + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert new_asn1 == -123 + + class BooleanTestCase(BaseTestCase): def testTruth(self): assert univ.Boolean(True) and univ.Boolean(1), 'Truth initializer fails' @@ -328,6 +347,24 @@ class BooleanTestCase(BaseTestCase): assert 0, 'constraint fail' +class BooleanPicklingTestCase(unittest.TestCase): + + def testSchemaPickling(self): + old_asn1 = univ.Boolean() + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert type(new_asn1) == univ.Boolean + assert old_asn1.isSameTypeWith(new_asn1) + + def testValuePickling(self): + old_asn1 = univ.Boolean(True) + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert new_asn1 == True + + class BitStringTestCase(BaseTestCase): def setUp(self): BaseTestCase.setUp(self) @@ -407,6 +444,24 @@ class BitStringTestCase(BaseTestCase): assert BitString('11000000011001').asInteger() == 12313 +class BitStringPicklingTestCase(unittest.TestCase): + + def testSchemaPickling(self): + old_asn1 = univ.BitString() + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert type(new_asn1) == univ.BitString + assert old_asn1.isSameTypeWith(new_asn1) + + def testValuePickling(self): + old_asn1 = univ.BitString((1, 0, 1, 0)) + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert new_asn1 == (1, 0, 1, 0) + + class OctetStringWithUnicodeMixIn(object): initializer = () @@ -545,6 +600,24 @@ class OctetStringTestCase(BaseTestCase): assert OctetString(hexValue="FA9823C43E43510DE3422") == ints2octs((250, 152, 35, 196, 62, 67, 81, 13, 227, 66, 32)) +class OctetStringPicklingTestCase(unittest.TestCase): + + def testSchemaPickling(self): + old_asn1 = univ.BitString() + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert type(new_asn1) == univ.BitString + assert old_asn1.isSameTypeWith(new_asn1) + + def testValuePickling(self): + old_asn1 = univ.BitString((1, 0, 1, 0)) + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert new_asn1 == (1, 0, 1, 0) + + class Null(BaseTestCase): def testInit(self): @@ -594,6 +667,24 @@ class Null(BaseTestCase): assert not Null('') +class NullPicklingTestCase(unittest.TestCase): + + def testSchemaPickling(self): + old_asn1 = univ.Null() + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert type(new_asn1) == univ.Null + assert old_asn1.isSameTypeWith(new_asn1) + + def testValuePickling(self): + old_asn1 = univ.Null('') + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert not new_asn1 + + class RealTestCase(BaseTestCase): def testFloat4BinEnc(self): assert univ.Real((0.25, 2, 3)) == 2.0, 'float initializer for binary encoding fails' @@ -727,6 +818,24 @@ class RealTestCase(BaseTestCase): assert Real(1.0) == 1.0 +class RealPicklingTestCase(unittest.TestCase): + + def testSchemaPickling(self): + old_asn1 = univ.Real() + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert type(new_asn1) == univ.Real + assert old_asn1.isSameTypeWith(new_asn1) + + def testValuePickling(self): + old_asn1 = univ.Real((1, 10, 3)) + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert new_asn1 == 1000 + + class ObjectIdentifier(BaseTestCase): def testStr(self): assert str(univ.ObjectIdentifier((1, 3, 6))) == '1.3.6', 'str() fails' @@ -787,6 +896,24 @@ class ObjectIdentifier(BaseTestCase): assert str(ObjectIdentifier((1, 3, 6))) == '1.3.6' +class ObjectIdentifierPicklingTestCase(unittest.TestCase): + + def testSchemaPickling(self): + old_asn1 = univ.ObjectIdentifier() + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert type(new_asn1) == univ.ObjectIdentifier + assert old_asn1.isSameTypeWith(new_asn1) + + def testValuePickling(self): + old_asn1 = univ.ObjectIdentifier('2.3.1.1.2') + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert new_asn1 == (2, 3, 1, 1, 2) + + class SequenceOf(BaseTestCase): def setUp(self): BaseTestCase.setUp(self) @@ -1027,6 +1154,26 @@ class SequenceOf(BaseTestCase): assert s.getComponentByPosition(0, instantiate=False) is univ.noValue +class SequenceOfPicklingTestCase(unittest.TestCase): + + def testSchemaPickling(self): + old_asn1 = univ.SequenceOf(componentType=univ.OctetString()) + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert type(new_asn1) == univ.SequenceOf + assert old_asn1.isSameTypeWith(new_asn1) + + def testValuePickling(self): + old_asn1 = univ.SequenceOf(componentType=univ.OctetString()) + old_asn1[0] = 'test' + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert new_asn1 + assert new_asn1 == [str2octs('test')] + + class Sequence(BaseTestCase): def setUp(self): BaseTestCase.setUp(self) @@ -1329,6 +1476,34 @@ class SequenceWithoutSchema(BaseTestCase): assert 'field-0' not in s +class SequencePicklingTestCase(unittest.TestCase): + + def testSchemaPickling(self): + old_asn1 = univ.Sequence( + componentType=namedtype.NamedTypes( + namedtype.NamedType('name', univ.OctetString()) + ) + ) + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert type(new_asn1) == univ.Sequence + assert old_asn1.isSameTypeWith(new_asn1) + + def testValuePickling(self): + old_asn1 = univ.Sequence( + componentType=namedtype.NamedTypes( + namedtype.NamedType('name', univ.OctetString()) + ) + ) + old_asn1['name'] = 'test' + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert new_asn1 + assert new_asn1['name'] == str2octs('test') + + class SetOf(BaseTestCase): def setUp(self): BaseTestCase.setUp(self) @@ -1357,6 +1532,27 @@ class SetOf(BaseTestCase): assert s == [str2octs('abc')] + +class SetOfPicklingTestCase(unittest.TestCase): + + def testSchemaPickling(self): + old_asn1 = univ.SetOf(componentType=univ.OctetString()) + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert type(new_asn1) == univ.SetOf + assert old_asn1.isSameTypeWith(new_asn1) + + def testValuePickling(self): + old_asn1 = univ.SetOf(componentType=univ.OctetString()) + old_asn1[0] = 'test' + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert new_asn1 + assert new_asn1 == [str2octs('test')] + + class Set(BaseTestCase): def setUp(self): BaseTestCase.setUp(self) @@ -1443,6 +1639,34 @@ class Set(BaseTestCase): assert s.getComponentByPosition(1, instantiate=False) is univ.noValue +class SetPicklingTestCase(unittest.TestCase): + + def testSchemaPickling(self): + old_asn1 = univ.Set( + componentType=namedtype.NamedTypes( + namedtype.NamedType('name', univ.OctetString()) + ) + ) + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert type(new_asn1) == univ.Set + assert old_asn1.isSameTypeWith(new_asn1) + + def testValuePickling(self): + old_asn1 = univ.Set( + componentType=namedtype.NamedTypes( + namedtype.NamedType('name', univ.OctetString()) + ) + ) + old_asn1['name'] = 'test' + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert new_asn1 + assert new_asn1['name'] == str2octs('test') + + class Choice(BaseTestCase): def setUp(self): BaseTestCase.setUp(self) @@ -1590,6 +1814,36 @@ class Choice(BaseTestCase): assert s.getComponentByPosition(1, instantiate=False) is univ.noValue +class ChoicePicklingTestCase(unittest.TestCase): + + def testSchemaPickling(self): + old_asn1 = univ.Choice( + componentType=namedtype.NamedTypes( + namedtype.NamedType('name', univ.OctetString()), + namedtype.NamedType('id', univ.Integer()) + ) + ) + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert type(new_asn1) == univ.Choice + assert old_asn1.isSameTypeWith(new_asn1) + + def testValuePickling(self): + old_asn1 = univ.Choice( + componentType=namedtype.NamedTypes( + namedtype.NamedType('name', univ.OctetString()), + namedtype.NamedType('id', univ.Integer()) + ) + ) + old_asn1['name'] = 'test' + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert new_asn1 + assert new_asn1['name'] == str2octs('test') + + suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__]) if __name__ == '__main__': diff --git a/tests/type/test_useful.py b/tests/type/test_useful.py index 82a97d7..2af17ff 100644 --- a/tests/type/test_useful.py +++ b/tests/type/test_useful.py @@ -7,6 +7,7 @@ import sys import datetime from copy import deepcopy +import pickle try: import unittest2 as unittest @@ -78,6 +79,24 @@ class GeneralizedTimeTestCase(BaseTestCase): assert dt == deepcopy(dt) +class GeneralizedTimePicklingTestCase(unittest.TestCase): + + def testSchemaPickling(self): + old_asn1 = useful.GeneralizedTime() + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert type(new_asn1) == useful.GeneralizedTime + assert old_asn1.isSameTypeWith(new_asn1) + + def testValuePickling(self): + old_asn1 = useful.GeneralizedTime("20170916234254+0130") + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert new_asn1 == old_asn1 + + class UTCTimeTestCase(BaseTestCase): def testFromDateTime(self): @@ -98,6 +117,25 @@ class UTCTimeTestCase(BaseTestCase): def testToDateTime4(self): assert datetime.datetime(2017, 7, 11, 0, 1) == useful.UTCTime('1707110001').asDateTime + +class UTCTimePicklingTestCase(unittest.TestCase): + + def testSchemaPickling(self): + old_asn1 = useful.UTCTime() + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert type(new_asn1) == useful.UTCTime + assert old_asn1.isSameTypeWith(new_asn1) + + def testValuePickling(self): + old_asn1 = useful.UTCTime("170711000102") + serialized = pickle.dumps(old_asn1) + assert serialized + new_asn1 = pickle.loads(serialized) + assert new_asn1 == old_asn1 + + suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__]) if __name__ == '__main__': -- cgit v1.2.3