diff options
Diffstat (limited to 'generator/google/protobuf/internal/containers.py')
-rw-r--r-- | generator/google/protobuf/internal/containers.py | 384 |
1 files changed, 365 insertions, 19 deletions
diff --git a/generator/google/protobuf/internal/containers.py b/generator/google/protobuf/internal/containers.py index 34b35f8..ce46d08 100644 --- a/generator/google/protobuf/internal/containers.py +++ b/generator/google/protobuf/internal/containers.py @@ -1,6 +1,6 @@ # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -41,6 +41,146 @@ are: __author__ = 'petar@google.com (Petar Petrov)' +import collections +import sys + +if sys.version_info[0] < 3: + # We would use collections.MutableMapping all the time, but in Python 2 it + # doesn't define __slots__. This causes two significant problems: + # + # 1. we can't disallow arbitrary attribute assignment, even if our derived + # classes *do* define __slots__. + # + # 2. we can't safely derive a C type from it without __slots__ defined (the + # interpreter expects to find a dict at tp_dictoffset, which we can't + # robustly provide. And we don't want an instance dict anyway. + # + # So this is the Python 2.7 definition of Mapping/MutableMapping functions + # verbatim, except that: + # 1. We declare __slots__. + # 2. We don't declare this as a virtual base class. The classes defined + # in collections are the interesting base classes, not us. + # + # Note: deriving from object is critical. It is the only thing that makes + # this a true type, allowing us to derive from it in C++ cleanly and making + # __slots__ properly disallow arbitrary element assignment. + + class Mapping(object): + __slots__ = () + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def __contains__(self, key): + try: + self[key] + except KeyError: + return False + else: + return True + + def iterkeys(self): + return iter(self) + + def itervalues(self): + for key in self: + yield self[key] + + def iteritems(self): + for key in self: + yield (key, self[key]) + + def keys(self): + return list(self) + + def items(self): + return [(key, self[key]) for key in self] + + def values(self): + return [self[key] for key in self] + + # Mappings are not hashable by default, but subclasses can change this + __hash__ = None + + def __eq__(self, other): + if not isinstance(other, collections.Mapping): + return NotImplemented + return dict(self.items()) == dict(other.items()) + + def __ne__(self, other): + return not (self == other) + + class MutableMapping(Mapping): + __slots__ = () + + __marker = object() + + def pop(self, key, default=__marker): + try: + value = self[key] + except KeyError: + if default is self.__marker: + raise + return default + else: + del self[key] + return value + + def popitem(self): + try: + key = next(iter(self)) + except StopIteration: + raise KeyError + value = self[key] + del self[key] + return key, value + + def clear(self): + try: + while True: + self.popitem() + except KeyError: + pass + + def update(*args, **kwds): + if len(args) > 2: + raise TypeError("update() takes at most 2 positional " + "arguments ({} given)".format(len(args))) + elif not args: + raise TypeError("update() takes at least 1 argument (0 given)") + self = args[0] + other = args[1] if len(args) >= 2 else () + + if isinstance(other, Mapping): + for key in other: + self[key] = other[key] + elif hasattr(other, "keys"): + for key in other.keys(): + self[key] = other[key] + else: + for key, value in other: + self[key] = value + for key, value in kwds.items(): + self[key] = value + + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + self[key] = default + return default + + collections.Mapping.register(Mapping) + collections.MutableMapping.register(MutableMapping) + +else: + # In Python 3 we can just use MutableMapping directly, because it defines + # __slots__. + MutableMapping = collections.MutableMapping + class BaseContainer(object): @@ -108,29 +248,34 @@ class RepeatedScalarFieldContainer(BaseContainer): def append(self, value): """Appends an item to the list. Similar to list.append().""" - self._type_checker.CheckValue(value) - self._values.append(value) + self._values.append(self._type_checker.CheckValue(value)) if not self._message_listener.dirty: self._message_listener.Modified() def insert(self, key, value): """Inserts the item at the specified position. Similar to list.insert().""" - self._type_checker.CheckValue(value) - self._values.insert(key, value) + self._values.insert(key, self._type_checker.CheckValue(value)) if not self._message_listener.dirty: self._message_listener.Modified() def extend(self, elem_seq): - """Extends by appending the given sequence. Similar to list.extend().""" - if not elem_seq: - return + """Extends by appending the given iterable. Similar to list.extend().""" - new_values = [] - for elem in elem_seq: - self._type_checker.CheckValue(elem) - new_values.append(elem) - self._values.extend(new_values) - self._message_listener.Modified() + if elem_seq is None: + return + try: + elem_seq_iter = iter(elem_seq) + except TypeError: + if not elem_seq: + # silently ignore falsy inputs :-/. + # TODO(ptucker): Deprecate this behavior. b/18413862 + return + raise + + new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter] + if new_values: + self._values.extend(new_values) + self._message_listener.Modified() def MergeFrom(self, other): """Appends the contents of another repeated field of the same type to this @@ -144,11 +289,21 @@ class RepeatedScalarFieldContainer(BaseContainer): self._values.remove(elem) self._message_listener.Modified() + def pop(self, key=-1): + """Removes and returns an item at a given index. Similar to list.pop().""" + value = self._values[key] + self.__delitem__(key) + return value + def __setitem__(self, key, value): """Sets the item on the specified position.""" - self._type_checker.CheckValue(value) - self._values[key] = value - self._message_listener.Modified() + if isinstance(key, slice): # PY3 + if key.step is not None: + raise ValueError('Extended slices not supported') + self.__setslice__(key.start, key.stop, value) + else: + self._values[key] = self._type_checker.CheckValue(value) + self._message_listener.Modified() def __getslice__(self, start, stop): """Retrieves the subset of items from between the specified indices.""" @@ -158,8 +313,7 @@ class RepeatedScalarFieldContainer(BaseContainer): """Sets the subset of items from between the specified indices.""" new_values = [] for value in values: - self._type_checker.CheckValue(value) - new_values.append(value) + new_values.append(self._type_checker.CheckValue(value)) self._values[start:stop] = new_values self._message_listener.Modified() @@ -183,6 +337,8 @@ class RepeatedScalarFieldContainer(BaseContainer): # We are presumably comparing against some other sequence type. return other == self._values +collections.MutableSequence.register(BaseContainer) + class RepeatedCompositeFieldContainer(BaseContainer): @@ -245,6 +401,12 @@ class RepeatedCompositeFieldContainer(BaseContainer): self._values.remove(elem) self._message_listener.Modified() + def pop(self, key=-1): + """Removes and returns an item at a given index. Similar to list.pop().""" + value = self._values[key] + self.__delitem__(key) + return value + def __getslice__(self, start, stop): """Retrieves the subset of items from between the specified indices.""" return self._values[start:stop] @@ -267,3 +429,187 @@ class RepeatedCompositeFieldContainer(BaseContainer): raise TypeError('Can only compare repeated composite fields against ' 'other repeated composite fields.') return self._values == other._values + + +class ScalarMap(MutableMapping): + + """Simple, type-checked, dict-like container for holding repeated scalars.""" + + # Disallows assignment to other attributes. + __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener'] + + def __init__(self, message_listener, key_checker, value_checker): + """ + Args: + message_listener: A MessageListener implementation. + The ScalarMap will call this object's Modified() method when it + is modified. + key_checker: A type_checkers.ValueChecker instance to run on keys + inserted into this container. + value_checker: A type_checkers.ValueChecker instance to run on values + inserted into this container. + """ + self._message_listener = message_listener + self._key_checker = key_checker + self._value_checker = value_checker + self._values = {} + + def __getitem__(self, key): + try: + return self._values[key] + except KeyError: + key = self._key_checker.CheckValue(key) + val = self._value_checker.DefaultValue() + self._values[key] = val + return val + + def __contains__(self, item): + # We check the key's type to match the strong-typing flavor of the API. + # Also this makes it easier to match the behavior of the C++ implementation. + self._key_checker.CheckValue(item) + return item in self._values + + # We need to override this explicitly, because our defaultdict-like behavior + # will make the default implementation (from our base class) always insert + # the key. + def get(self, key, default=None): + if key in self: + return self[key] + else: + return default + + def __setitem__(self, key, value): + checked_key = self._key_checker.CheckValue(key) + checked_value = self._value_checker.CheckValue(value) + self._values[checked_key] = checked_value + self._message_listener.Modified() + + def __delitem__(self, key): + del self._values[key] + self._message_listener.Modified() + + def __len__(self): + return len(self._values) + + def __iter__(self): + return iter(self._values) + + def __repr__(self): + return repr(self._values) + + def MergeFrom(self, other): + self._values.update(other._values) + self._message_listener.Modified() + + def InvalidateIterators(self): + # It appears that the only way to reliably invalidate iterators to + # self._values is to ensure that its size changes. + original = self._values + self._values = original.copy() + original[None] = None + + # This is defined in the abstract base, but we can do it much more cheaply. + def clear(self): + self._values.clear() + self._message_listener.Modified() + + +class MessageMap(MutableMapping): + + """Simple, type-checked, dict-like container for with submessage values.""" + + # Disallows assignment to other attributes. + __slots__ = ['_key_checker', '_values', '_message_listener', + '_message_descriptor'] + + def __init__(self, message_listener, message_descriptor, key_checker): + """ + Args: + message_listener: A MessageListener implementation. + The ScalarMap will call this object's Modified() method when it + is modified. + key_checker: A type_checkers.ValueChecker instance to run on keys + inserted into this container. + value_checker: A type_checkers.ValueChecker instance to run on values + inserted into this container. + """ + self._message_listener = message_listener + self._message_descriptor = message_descriptor + self._key_checker = key_checker + self._values = {} + + def __getitem__(self, key): + try: + return self._values[key] + except KeyError: + key = self._key_checker.CheckValue(key) + new_element = self._message_descriptor._concrete_class() + new_element._SetListener(self._message_listener) + self._values[key] = new_element + self._message_listener.Modified() + + return new_element + + def get_or_create(self, key): + """get_or_create() is an alias for getitem (ie. map[key]). + + Args: + key: The key to get or create in the map. + + This is useful in cases where you want to be explicit that the call is + mutating the map. This can avoid lint errors for statements like this + that otherwise would appear to be pointless statements: + + msg.my_map[key] + """ + return self[key] + + # We need to override this explicitly, because our defaultdict-like behavior + # will make the default implementation (from our base class) always insert + # the key. + def get(self, key, default=None): + if key in self: + return self[key] + else: + return default + + def __contains__(self, item): + return item in self._values + + def __setitem__(self, key, value): + raise ValueError('May not set values directly, call my_map[key].foo = 5') + + def __delitem__(self, key): + del self._values[key] + self._message_listener.Modified() + + def __len__(self): + return len(self._values) + + def __iter__(self): + return iter(self._values) + + def __repr__(self): + return repr(self._values) + + def MergeFrom(self, other): + for key in other: + # According to documentation: "When parsing from the wire or when merging, + # if there are duplicate map keys the last key seen is used". + if key in self: + del self[key] + self[key].CopyFrom(other[key]) + # self._message_listener.Modified() not required here, because + # mutations to submessages already propagate. + + def InvalidateIterators(self): + # It appears that the only way to reliably invalidate iterators to + # self._values is to ensure that its size changes. + original = self._values + self._values = original.copy() + original[None] = None + + # This is defined in the abstract base, but we can do it much more cheaply. + def clear(self): + self._values.clear() + self._message_listener.Modified() |