diff options
author | Ilya Etingof <etingof@gmail.com> | 2017-02-24 18:13:04 +0100 |
---|---|---|
committer | Ilya Etingof <etingof@gmail.com> | 2017-02-24 18:13:04 +0100 |
commit | 679408903afee6d76b682376aedf0eeee496b627 (patch) | |
tree | 7e49e98eae8272e5d3482bf2b24161c8f5535857 /pyasn1 | |
parent | 61ceeae369e0e41b6b8e57437430200418b51bef (diff) | |
download | pyasn1-679408903afee6d76b682376aedf0eeee496b627.tar.gz |
BitString type and codecs reimplemented
Targeting high performance and convenience of use.
Sampling just BitString encoding/decoding performance -- the new implementation is 100x faster. ;-\
Diffstat (limited to 'pyasn1')
-rw-r--r-- | pyasn1/codec/ber/decoder.py | 72 | ||||
-rw-r--r-- | pyasn1/codec/ber/encoder.py | 105 | ||||
-rw-r--r-- | pyasn1/compat/integer.py | 96 | ||||
-rw-r--r-- | pyasn1/type/univ.py | 273 |
4 files changed, 296 insertions, 250 deletions
diff --git a/pyasn1/codec/ber/decoder.py b/pyasn1/codec/ber/decoder.py index d01754f..7d1735d 100644 --- a/pyasn1/codec/ber/decoder.py +++ b/pyasn1/codec/ber/decoder.py @@ -4,10 +4,10 @@ # Copyright (c) 2005-2017, Ilya Etingof <etingof@gmail.com> # License: http://pyasn1.sf.net/license.html # -from sys import version_info from pyasn1.type import base, tag, univ, char, useful, tagmap from pyasn1.codec.ber import eoo -from pyasn1.compat.octets import str2octs, oct2int, isOctetsType +from pyasn1.compat.octets import oct2int, isOctetsType +from pyasn1.compat.integer import from_bytes from pyasn1 import debug, error __all__ = ['decode'] @@ -89,21 +89,6 @@ explicitTagDecoder = ExplicitTagDecoder() class IntegerDecoder(AbstractSimpleDecoder): protoComponent = univ.Integer(0) - if version_info[0:2] < (3, 2): - @staticmethod - def _from_octets(octets, signed=False): - value = long(octets.encode('hex'), 16) - - if signed and oct2int(octets[0]) & 0x80: - return value - (1 << len(octets) * 8) - - return value - - else: - @staticmethod - def _from_octets(octets, signed=False): - return int.from_bytes(octets, 'big', signed=signed) - def valueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, length, state, decodeFun, substrateFun): head, tail = substrate[:length], substrate[length:] @@ -111,7 +96,7 @@ class IntegerDecoder(AbstractSimpleDecoder): if not head: return self._createComponent(asn1Spec, tagSet, 0), tail - value = self._from_octets(head, signed=True) + value = from_bytes(head, signed=True) return self._createComponent(asn1Spec, tagSet, value), tail @@ -140,46 +125,41 @@ class BitStringDecoder(AbstractSimpleDecoder): 'Trailing bits overflow %s' % trailingBits ) head = head[1:] - lsb = p = 0 - l = len(head) - 1 - b = [] - while p <= l: - if p == l: - lsb = trailingBits - j = 7 - o = oct2int(head[p]) - while j >= lsb: - b.append((o >> j) & 0x01) - j -= 1 - p += 1 - return self._createComponent(asn1Spec, tagSet, b), tail + value = 
self.protoComponent.fromOctetString(head, trailingBits) + return self._createComponent(asn1Spec, tagSet, value), tail + if not self.supportConstructedForm: raise error.PyAsn1Error('Constructed encoding form prohibited at %s' % self.__class__.__name__) - r = self._createComponent(asn1Spec, tagSet, ()) + + bitString = self._createComponent(asn1Spec, tagSet) + if substrateFun: - return substrateFun(r, substrate, length) + return substrateFun(bitString, substrate, length) + while head: component, head = decodeFun(head, self.protoComponent) - r = r + component - return r, tail + bitString += component + + return bitString, tail def indefLenValueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, length, state, decodeFun, substrateFun): - r = self._createComponent(asn1Spec, tagSet, '') + bitString = self._createComponent(asn1Spec, tagSet) + if substrateFun: - return substrateFun(r, substrate, length) + return substrateFun(bitString, substrate, length) + while substrate: - component, substrate = decodeFun(substrate, self.protoComponent, - allowEoo=True) - if eoo.endOfOctets.isSameTypeWith(component) and \ - component == eoo.endOfOctets: + component, substrate = decodeFun(substrate, self.protoComponent, allowEoo=True) + if eoo.endOfOctets.isSameTypeWith(component) and component == eoo.endOfOctets: break - r = r + component + + bitString += component + else: - raise error.SubstrateUnderrunError( - 'No EOO seen before substrate ends' - ) - return r, substrate + raise error.SubstrateUnderrunError('No EOO seen before substrate ends') + + return bitString, substrate class OctetStringDecoder(AbstractSimpleDecoder): diff --git a/pyasn1/codec/ber/encoder.py b/pyasn1/codec/ber/encoder.py index d69e283..223554d 100644 --- a/pyasn1/codec/ber/encoder.py +++ b/pyasn1/codec/ber/encoder.py @@ -4,12 +4,10 @@ # Copyright (c) 2005-2017, Ilya Etingof <etingof@gmail.com> # License: http://pyasn1.sf.net/license.html # -from sys import version_info -if version_info[0] < 3: - from binascii 
import a2b_hex from pyasn1.type import base, tag, univ, char, useful from pyasn1.codec.ber import eoo from pyasn1.compat.octets import int2oct, oct2int, ints2octs, null, str2octs +from pyasn1.compat.integer import to_bytes from pyasn1 import debug, error __all__ = ['encode'] @@ -107,68 +105,6 @@ class IntegerEncoder(AbstractItemEncoder): supportCompactZero = False encodedZero = ints2octs((0,)) - if version_info[0:2] > (3, 1): - @staticmethod - def _to_octets(value, signed=False): - length = value.bit_length() - - if signed and length % 8 == 0: - length += 1 - - return value.to_bytes(length // 8 + 1, 'big', signed=signed) - - else: - @staticmethod - def _to_octets(value, signed=False): - if value < 0: - if signed: - # bits in unsigned number - hexValue = hex(abs(value)) - bits = len(hexValue) - 2 - if hexValue.endswith('L'): - bits -= 1 - if bits & 1: - bits += 1 - bits *= 4 - - # two's complement form - maxValue = 1 << bits - valueToEncode = (value + maxValue) % maxValue - - else: - raise OverflowError('can\'t convert negative int to unsigned') - else: - valueToEncode = value - - hexValue = hex(valueToEncode)[2:] - if hexValue.endswith('L'): - hexValue = hexValue[:-1] - - if len(hexValue) & 1: - hexValue = '0' + hexValue - - # padding may be needed for two's complement encoding - if value != valueToEncode: - hexLength = len(hexValue) // 2 - - padLength = bits // 8 - - if padLength > hexLength: - hexValue = '00' * (padLength - hexLength) + hexValue - - firstOctet = int(hexValue[:2], 16) - - if signed: - if firstOctet & 0x80: - if value >= 0: - hexValue = '00' + hexValue - elif value < 0: - hexValue = 'ff' + hexValue - - octets_value = a2b_hex(hexValue) - - return octets_value - def encodeValue(self, encodeFun, value, defMode, maxChunkSize): if value == 0: # de-facto way to encode zero @@ -177,34 +113,27 @@ class IntegerEncoder(AbstractItemEncoder): else: return self.encodedZero, 0 - return self._to_octets(int(value), signed=True), 0 + return to_bytes(int(value), 
signed=True), 0 class BitStringEncoder(AbstractItemEncoder): def encodeValue(self, encodeFun, value, defMode, maxChunkSize): - if not maxChunkSize or len(value) <= maxChunkSize * 8: - out_len = (len(value) + 7) // 8 - out_list = out_len * [0] - j = 7 - i = -1 - for val in value: - j += 1 - if j == 8: - i += 1 - j = 0 - out_list[i] |= val << (7 - j) - return int2oct(7 - j) + ints2octs(out_list), 0 + if len(value) % 8: + alignedValue = value << (8 - len(value) % 8) else: - pos = 0 - substrate = null - while True: - # count in octets - v = value.clone(value[pos * 8:pos * 8 + maxChunkSize * 8]) - if not v: - break - substrate = substrate + encodeFun(v, defMode, maxChunkSize) - pos += maxChunkSize - return substrate, 1 + alignedValue = value + + if not maxChunkSize or len(alignedValue) <= maxChunkSize * 8: + substrate = alignedValue.asOctets() + return int2oct(len(substrate) * 8 - len(value)) + substrate, 0 + + stop = 0 + substrate = null + while stop < len(value): + start = stop + stop = min(start + maxChunkSize * 8, len(value)) + substrate += encodeFun(alignedValue[start:stop], defMode, maxChunkSize) + return substrate, 1 class OctetStringEncoder(AbstractItemEncoder): diff --git a/pyasn1/compat/integer.py b/pyasn1/compat/integer.py new file mode 100644 index 0000000..ae9c7e1 --- /dev/null +++ b/pyasn1/compat/integer.py @@ -0,0 +1,96 @@ +# +# This file is part of pyasn1 software. 
+# +# Copyright (c) 2005-2017, Ilya Etingof <etingof@gmail.com> +# License: http://pyasn1.sf.net/license.html +# +import sys +if sys.version_info[0:2] < (3, 2): + from binascii import a2b_hex, b2a_hex +from pyasn1.compat.octets import oct2int, null + +if sys.version_info[0:2] < (3, 2): + def from_bytes(octets, signed=False): + value = long(b2a_hex(str(octets)), 16) + + if signed and oct2int(octets[0]) & 0x80: + return value - (1 << len(octets) * 8) + + return value + + def to_bytes(value, signed=False, length=0): + if value < 0: + if signed: + bits = bitLength(value) + + # two's complement form + maxValue = 1 << bits + valueToEncode = (value + maxValue) % maxValue + + else: + raise OverflowError('can\'t convert negative int to unsigned') + elif value == 0 and length == 0: + return null + else: + bits = 0 + valueToEncode = value + + hexValue = hex(valueToEncode)[2:] + if hexValue.endswith('L'): + hexValue = hexValue[:-1] + + if len(hexValue) & 1: + hexValue = '0' + hexValue + + # padding may be needed for two's complement encoding + if value != valueToEncode or length: + hexLength = len(hexValue) * 4 + + padLength = max(length, bits) + + if padLength > hexLength: + hexValue = '00' * ((padLength - hexLength - 1) // 8 + 1) + hexValue + elif length and hexLength - length > 7: + raise OverflowError('int too big to convert') + + firstOctet = int(hexValue[:2], 16) + + if signed: + if firstOctet & 0x80: + if value >= 0: + hexValue = '00' + hexValue + elif value < 0: + hexValue = 'ff' + hexValue + + octets_value = a2b_hex(hexValue) + + return octets_value + + def bitLength(number): + # bits in unsigned number + hexValue = hex(abs(number)) + bits = len(hexValue) - 2 + if hexValue.endswith('L'): + bits -= 1 + if bits & 1: + bits += 1 + bits *= 4 + # TODO: strip lhs zeros + return bits + +else: + + def from_bytes(octets, signed=False): + return int.from_bytes(bytes(octets), 'big', signed=signed) + + def to_bytes(value, signed=False, length=0): + length = 
max(value.bit_length(), length) + + if signed and length % 8 == 0: + length += 1 + + return value.to_bytes(length // 8 + (length % 8 and 1 or 0), 'big', signed=signed) + + def bitLength(number): + return int(number).bit_length() + diff --git a/pyasn1/type/univ.py b/pyasn1/type/univ.py index de50eab..e6843a3 100644 --- a/pyasn1/type/univ.py +++ b/pyasn1/type/univ.py @@ -9,7 +9,7 @@ import sys import math from pyasn1.type import base, tag, constraint, namedtype, namedval, tagmap from pyasn1.codec.ber import eoo -from pyasn1.compat import octets +from pyasn1.compat import octets, integer, binary from pyasn1 import error NoValue = base.NoValue @@ -361,12 +361,18 @@ class Boolean(Integer): class BitString(base.AbstractSimpleAsn1Item): """Create |ASN.1| type or object. - |ASN.1| objects are immutable and duck-type Python :class:`tuple` objects (tuple of bits). + |ASN.1| objects are immutable and duck-type both Python :class:`tuple` (as a tuple + of bits) and :class:`int` objects. + + |ASN.1| objects can be initialized from a string literal or a sequence of bits + or an integer. Then |ASN.1| objects can be worked as with a sequence (including + concatenation) or a number (including bitshifting). Parameters ---------- value : :class:`int`, :class:`str` or |ASN.1| object - Python integer or string literal or |ASN.1| object. + Python integer or string literal representing binary or hexadecimal + number or sequence of integer bits or |ASN.1| object. 
tagSet: :py:class:`~pyasn1.type.tag.TagSet` Object representing non-default ASN.1 tag(s) @@ -406,6 +412,25 @@ class BitString(base.AbstractSimpleAsn1Item): defaultBinValue = defaultHexValue = noValue + if sys.version_info[0] < 3: + SizedIntegerBase = long + else: + SizedIntegerBase = int + + class SizedInteger(SizedIntegerBase): + bitLength = leadingZeroBits = None + + def setBitLength(self, bitLength): + self.bitLength = bitLength + self.leadingZeroBits = max(bitLength - integer.bitLength(self), 0) + return self + + def __len__(self): + if self.bitLength is None: + self.setBitLength(integer.bitLength(self)) + + return self.bitLength + def __init__(self, value=noValue, tagSet=None, subtypeSpec=None, namedValues=None, binValue=noValue, hexValue=noValue): if namedValues is None: @@ -539,138 +564,145 @@ class BitString(base.AbstractSimpleAsn1Item): return self.__class__(value, tagSet, subtypeSpec, namedValues, binValue, hexValue) def __str__(self): - return ''.join([str(x) for x in self._value]) + return self.asBinary() + + def __eq__(self, other): + other = self.prettyIn(other) + return self is other or self._value == other and len(self._value) == len(other) + + def __ne__(self, other): + other = self.prettyIn(other) + return self._value != other or len(self._value) != len(other) + + def __lt__(self, other): + other = self.prettyIn(other) + return len(self._value) < len(other) or len(self._value) == len(other) and self._value < other + + def __le__(self, other): + other = self.prettyIn(other) + return len(self._value) <= len(other) or len(self._value) == len(other) and self._value <= other + + def __gt__(self, other): + other = self.prettyIn(other) + return len(self._value) > len(other) or len(self._value) == len(other) and self._value > other + + def __ge__(self, other): + other = self.prettyIn(other) + return len(self._value) >= len(other) or len(self._value) == len(other) and self._value >= other # Immutable sequence object protocol def __len__(self): - if 
self._len is None: - self._len = len(self._value) - return self._len + return len(self._value) def __getitem__(self, i): if isinstance(i, slice): - return self.clone(operator.getitem(self._value, i)) + return self.clone([self[x] for x in range(*i.indices(len(self)))]) else: - return self._value[i] + length = len(self._value) - 1 + if i > length or i < 0: + raise IndexError('bit index out of range') + return (self._value >> (length - i)) & 1 - def __contains__(self, bit): - return bit in self._value + def __iter__(self): + length = len(self._value) + while length: + length -= 1 + yield (self._value >> length) & 1 def __reversed__(self): - return reversed(self._value) + return reversed(tuple(self)) + + # arithmetic operators def __add__(self, value): - return self.clone(self._value + value) + value = self.prettyIn(value) + return self.clone(self.SizedInteger(self._value << len(value) | value).setBitLength(len(self._value) + len(value))) def __radd__(self, value): - return self.clone(value + self._value) + value = self.prettyIn(value) + return self.clone(self.SizedInteger(value << len(self._value) | self._value).setBitLength(len(self._value) + len(value))) def __mul__(self, value): - return self.clone(self._value * value) + bitString = self._value + while value > 1: + bitString <<= len(self._value) + bitString |= self._value + value -= 1 + return self.clone(bitString) def __rmul__(self, value): return self * value - def asNumbers(self, padding=True): - """Get |ASN.1| value as a sequence of 8-bit integers. + def __lshift__(self, count): + return self.clone(self.SizedInteger(self._value << count).setBitLength(len(self._value) + count)) - Parameters - ---------- - padding: :class:`bool` - Allow left-padding if |ASN.1| value length is not a multiples of eight. 
+ def __rshift__(self, count): + return self.clone(self.SizedInteger(self._value >> count).setBitLength(max(0, len(self._value) - count))) - Raises - ------ - : :py:class:`pyasn1.error.PyAsn1Error` - If |ASN.1| value length is not multiples of eight and no padding is allowed. - """ - if not padding and len(self) % 8 != 0: - raise error.PyAsn1Error('BIT STRING length is not a multiple of 8') + def __int__(self): + return self._value - if padding in self.__asNumbersCache: - return self.__asNumbersCache[padding] + def __float__(self): + return float(self._value) - result = [] - bitstring = list(self) - while len(bitstring) % 8: - bitstring.insert(0, 0) - bitIndex = 0 - while bitIndex < len(bitstring): - byte = 0 - for x in range(8): - byte |= bitstring[bitIndex + x] << (7 - x) - result.append(byte) - bitIndex += 8 + if sys.version_info[0] < 3: + def __long__(self): + return self._value - self.__asNumbersCache[padding] = tuple(result) + def asNumbers(self): + """Get |ASN.1| value as a sequence of 8-bit integers. - return self.__asNumbersCache[padding] + If |ASN.1| object length is not a multiple of 8, result + will be left-padded with zeros. + """ + return tuple(octets.octs2ints(self.asOctets())) - def asOctets(self, padding=True): + def asOctets(self): """Get |ASN.1| value as a sequence of octets. - Parameters - ---------- - padding: :class:`bool` - Allow left-padding if |ASN.1| value length is not a multiples of eight. - - Raises - ------ - : :py:class:`pyasn1.error.PyAsn1Error` - If |ASN.1| value length is not multiples of eight and no padding is allowed. + If |ASN.1| object length is not a multiple of 8, result + will be left-padded with zeros. """ - return octets.ints2octs(self.asNumbers(padding)) + return integer.to_bytes(self._value, length=len(self)) - def asInteger(self, padding=True): + def asInteger(self): """Get |ASN.1| value as a single integer value. 
+ """ + return self._value - Parameters - ---------- - padding: :class:`bool` - Allow left-padding if |ASN.1| value length is not a multiples of eight. - - Raises - ------ - : :py:class:`pyasn1.error.PyAsn1Error` - If |ASN.1| value length is not multiples of eight and no padding is allowed. + def asBinary(self): + """Get |ASN.1| value as a text string of bits. """ - accumulator = 0 - for byte in self.asNumbers(padding): - accumulator <<= 8 - accumulator |= byte - return accumulator + binString = binary.bin(self._value)[2:] + return '0'*(len(self._value) - len(binString)) + binString - @staticmethod - def fromHexString(value): - r = [] - for v in value: - v = int(v, 16) - i = 4 - while i: - i -= 1 - r.append((v >> i) & 0x01) - return tuple(r) + @classmethod + def fromHexString(cls, value): + try: + return cls.SizedInteger(value, 16).setBitLength(len(value) * 4) - @staticmethod - def fromBinaryString(value): - r = [] - for v in value: - if v in ('0', '1'): - r.append(int(v)) - else: - raise error.PyAsn1Error( - 'Non-binary BIT STRING initializer %s' % (v,) - ) - return tuple(r) + except ValueError: + raise error.PyAsn1Error('%s.fromHexString() error: %s' % (cls.__name__, sys.exc_info()[1])) + + @classmethod + def fromBinaryString(cls, value): + try: + return cls.SizedInteger(value or '0', 2).setBitLength(len(value)) + + except ValueError: + raise error.PyAsn1Error('%s.fromBinaryString() error: %s' % (cls.__name__, sys.exc_info()[1])) + + @classmethod + def fromOctetString(cls, value, padding=0): + return cls(cls.SizedInteger(integer.from_bytes(value) >> padding).setBitLength(len(value) * 8 - padding)) def prettyIn(self, value): - r = [] - if not value: - return () + if octets.isStringType(value): + if not value: + return self.SizedInteger(0).setBitLength(0) - elif octets.isStringType(value): - if value[0] == '\'': # "'1011'B" -- ASN.1 schema representation + elif value[0] == '\'': # "'1011'B" -- ASN.1 schema representation (deprecated) if value[-2:] == '\'B': return 
self.fromBinaryString(value[1:-2]) elif value[-2:] == '\'H': @@ -679,31 +711,40 @@ class BitString(base.AbstractSimpleAsn1Item): raise error.PyAsn1Error( 'Bad BIT STRING value notation %s' % (value,) ) + elif self.__namedValues and not value.isdigit(): # named bits like 'Urgent, Active' - for i in value.split(','): - j = self.__namedValues.getValue(i) - if j is None: + number = 0 + highestBitPosition = 0 + for namedBit in value.split(','): + bitPosition = self.__namedValues.getValue(namedBit) + if bitPosition is None: raise error.PyAsn1Error( - 'Unknown bit identifier \'%s\'' % (i,) + 'Unknown bit identifier \'%s\'' % (namedBit,) ) - if j >= len(r): - r.extend([0] * (j - len(r) + 1)) - r[j] = 1 - return tuple(r) + + number |= (1 << bitPosition) + + highestBitPosition = max(highestBitPosition, bitPosition) + + return self.SizedInteger(number).setBitLength(highestBitPosition + 1) + + elif value.startswith('0x'): + return self.fromHexString(value[2:]) + + elif value.startswith('0b'): + return self.fromBinaryString(value[2:]) + else: # assume plain binary string like '1011' return self.fromBinaryString(value) elif isinstance(value, (tuple, list)): - r = tuple(value) - for b in r: - if b and b != 1: - raise error.PyAsn1Error( - 'Non-binary BitString initializer \'%s\'' % (r,) - ) - return r + return self.fromBinaryString(''.join([b and '1' or '0' for b in value])) - elif isinstance(value, BitString): - return tuple(value) + elif isinstance(value, (self.SizedInteger, BitString)): + return self.SizedInteger(value).setBitLength(len(value)) + + elif isinstance(value, intTypes): + return self.SizedInteger(value) else: raise error.PyAsn1Error( @@ -941,10 +982,10 @@ class OctetString(base.AbstractSimpleAsn1Item): 'Can\'t decode string \'%s\' with \'%s\' codec' % (self._value, self._encoding) ) - def asOctets(self, padding=True): + def asOctets(self): return str(self._value) - def asNumbers(self, padding=True): + def asNumbers(self): if self.__asNumbersCache is None: 
self.__asNumbersCache = tuple([ord(x) for x in self._value]) return self.__asNumbersCache @@ -981,10 +1022,10 @@ class OctetString(base.AbstractSimpleAsn1Item): def __bytes__(self): return bytes(self._value) - def asOctets(self, padding=True): + def asOctets(self): return bytes(self._value) - def asNumbers(self, padding=True): + def asNumbers(self): if self.__asNumbersCache is None: self.__asNumbersCache = tuple(self._value) return self.__asNumbersCache |