diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /contrib/python/protobuf/py3 | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'contrib/python/protobuf/py3')
67 files changed, 28067 insertions, 0 deletions
diff --git a/contrib/python/protobuf/py3/.yandex_meta/devtools.copyrights.report b/contrib/python/protobuf/py3/.yandex_meta/devtools.copyrights.report new file mode 100644 index 0000000000..90d0362951 --- /dev/null +++ b/contrib/python/protobuf/py3/.yandex_meta/devtools.copyrights.report @@ -0,0 +1,110 @@ +# File format ($ symbol means the beginning of a line): +# +# $ # this message +# $ # ======================= +# $ # comments (all commentaries should starts with some number of spaces and # symbol) +# $ IGNORE_FILES {file1.ext1} {file2.ext2} - (optional) ignore listed files when generating license macro and credits +# $ +# ${action} {license id} {license text hash} +# $BELONGS ./ya/make/file/relative/path/1/ya.make ./ya/make/2/ya.make +# ${all_file_action} filename +# $ # user commentaries (many lines) +# $ generated description - files with this license, license text... (some number of lines that starts with some number of spaces, do not modify) +# ${action} {license spdx} {license text hash} +# $BELONGS ./ya/make/file/relative/path/3/ya.make +# ${all_file_action} filename +# $ # user commentaries +# $ generated description +# $ ... +# +# You can modify action, all_file_action and add commentaries +# Available actions: +# keep - keep license in contrib and use in credits +# skip - skip license +# remove - remove all files with this license +# rename - save license text/links into licenses texts file, but not store SPDX into LINCENSE macro. You should store correct license id into devtools.license.spdx.txt file +# +# {all file action} records will be generated when license text contains filename that exists on filesystem (in contrib directory) +# We suppose that that files can contain some license info +# Available all file actions: +# FILE_IGNORE - ignore file (do nothing) +# FILE_INCLUDE - include all file data into licenses text file +# ======================= + +KEEP COPYRIGHT_SERVICE_LABEL 29be00457f74dcf5f2b494d41d6a6c10 +BELONGS ya.make + License text: + \# Copyright 2007 Google Inc. All Rights Reserved. + Scancode info: + Original SPDX id: COPYRIGHT_SERVICE_LABEL + Score : 100.00 + Match type : COPYRIGHT + Files with this license: + google/protobuf/__init__.py [31:31] + +KEEP COPYRIGHT_SERVICE_LABEL 7ddb2995f48012001146c0eb94d23367 +BELONGS ya.make + License text: + \# Copyright 2008 Google Inc. All rights reserved. + Scancode info: + Original SPDX id: COPYRIGHT_SERVICE_LABEL + Score : 100.00 + Match type : COPYRIGHT + Files with this license: + google/protobuf/__init__.py [2:2] + google/protobuf/descriptor.py [2:2] + google/protobuf/descriptor_database.py [2:2] + google/protobuf/descriptor_pool.py [2:2] + google/protobuf/internal/__init__.py [2:2] + google/protobuf/internal/_parameterized.py [4:4] + google/protobuf/internal/api_implementation.cc [2:2] + google/protobuf/internal/api_implementation.py [2:2] + google/protobuf/internal/containers.py [2:2] + google/protobuf/internal/decoder.py [2:2] + google/protobuf/internal/encoder.py [2:2] + google/protobuf/internal/enum_type_wrapper.py [2:2] + google/protobuf/internal/extension_dict.py [2:2] + google/protobuf/internal/message_listener.py [2:2] + google/protobuf/internal/python_message.py [2:2] + google/protobuf/internal/type_checkers.py [2:2] + google/protobuf/internal/well_known_types.py [2:2] + google/protobuf/internal/wire_format.py [2:2] + google/protobuf/json_format.py [2:2] + google/protobuf/message.py [2:2] + google/protobuf/message_factory.py [2:2] + google/protobuf/proto_api.h [2:2] + google/protobuf/proto_builder.py [2:2] + google/protobuf/pyext/cpp_message.py [2:2] + google/protobuf/pyext/descriptor.cc [2:2] + google/protobuf/pyext/descriptor.h [2:2] + google/protobuf/pyext/descriptor_containers.cc [2:2] + google/protobuf/pyext/descriptor_containers.h [2:2] + google/protobuf/pyext/descriptor_database.cc [2:2] + google/protobuf/pyext/descriptor_database.h [2:2] + google/protobuf/pyext/descriptor_pool.cc [2:2] + google/protobuf/pyext/descriptor_pool.h [2:2] + google/protobuf/pyext/extension_dict.cc [2:2] + google/protobuf/pyext/extension_dict.h [2:2] + google/protobuf/pyext/field.cc [2:2] + google/protobuf/pyext/field.h [2:2] + google/protobuf/pyext/map_container.cc [2:2] + google/protobuf/pyext/map_container.h [2:2] + google/protobuf/pyext/message.cc [2:2] + google/protobuf/pyext/message.h [2:2] + google/protobuf/pyext/message_factory.cc [2:2] + google/protobuf/pyext/message_factory.h [2:2] + google/protobuf/pyext/message_module.cc [2:2] + google/protobuf/pyext/repeated_composite_container.cc [2:2] + google/protobuf/pyext/repeated_composite_container.h [2:2] + google/protobuf/pyext/repeated_scalar_container.cc [2:2] + google/protobuf/pyext/repeated_scalar_container.h [2:2] + google/protobuf/pyext/safe_numerics.h [2:2] + google/protobuf/pyext/scoped_pyobject_ptr.h [2:2] + google/protobuf/pyext/unknown_fields.cc [2:2] + google/protobuf/pyext/unknown_fields.h [2:2] + google/protobuf/reflection.py [2:2] + google/protobuf/service.py [2:2] + google/protobuf/service_reflection.py [2:2] + google/protobuf/symbol_database.py [2:2] + google/protobuf/text_encoding.py [2:2] + google/protobuf/text_format.py [2:2] diff --git a/contrib/python/protobuf/py3/.yandex_meta/devtools.licenses.report b/contrib/python/protobuf/py3/.yandex_meta/devtools.licenses.report new file mode 100644 index 0000000000..008ecfdee6 --- /dev/null +++ b/contrib/python/protobuf/py3/.yandex_meta/devtools.licenses.report @@ -0,0 +1,109 @@ +# File format ($ symbol means the beginning of a line): +# +# $ # this message +# $ # ======================= +# $ # comments (all commentaries should starts with some number of spaces and # symbol) +# $ IGNORE_FILES {file1.ext1} {file2.ext2} - (optional) ignore listed files when generating license macro and credits +# $ +# ${action} {license id} {license text hash} +# $BELONGS ./ya/make/file/relative/path/1/ya.make ./ya/make/2/ya.make +# ${all_file_action} filename +# $ # user commentaries (many lines) +# $ generated description - files with this license, license text... (some number of lines that starts with some number of spaces, do not modify) +# ${action} {license spdx} {license text hash} +# $BELONGS ./ya/make/file/relative/path/3/ya.make +# ${all_file_action} filename +# $ # user commentaries +# $ generated description +# $ ... +# +# You can modify action, all_file_action and add commentaries +# Available actions: +# keep - keep license in contrib and use in credits +# skip - skip license +# remove - remove all files with this license +# rename - save license text/links into licenses texts file, but not store SPDX into LINCENSE macro. You should store correct license id into devtools.license.spdx.txt file +# +# {all file action} records will be generated when license text contains filename that exists on filesystem (in contrib directory) +# We suppose that that files can contain some license info +# Available all file actions: +# FILE_IGNORE - ignore file (do nothing) +# FILE_INCLUDE - include all file data into licenses text file +# ======================= + +KEEP BSD-3-Clause 6aa235708ac9f5dd8e5c6ac415fc5837 +BELONGS ya.make + Note: matched license text is too long. Read it in the source files. + Scancode info: + Original SPDX id: BSD-3-Clause + Score : 100.00 + Match type : TEXT + Links : http://www.opensource.org/licenses/BSD-3-Clause, https://spdx.org/licenses/BSD-3-Clause + Files with this license: + google/protobuf/internal/api_implementation.cc [5:29] + google/protobuf/proto_api.h [5:29] + google/protobuf/pyext/descriptor.cc [5:29] + google/protobuf/pyext/descriptor.h [5:29] + google/protobuf/pyext/descriptor_containers.cc [5:29] + google/protobuf/pyext/descriptor_containers.h [5:29] + google/protobuf/pyext/descriptor_database.cc [5:29] + google/protobuf/pyext/descriptor_database.h [5:29] + google/protobuf/pyext/descriptor_pool.cc [5:29] + google/protobuf/pyext/descriptor_pool.h [5:29] + google/protobuf/pyext/extension_dict.cc [5:29] + google/protobuf/pyext/extension_dict.h [5:29] + google/protobuf/pyext/field.cc [5:29] + google/protobuf/pyext/field.h [5:29] + google/protobuf/pyext/map_container.cc [5:29] + google/protobuf/pyext/map_container.h [5:29] + google/protobuf/pyext/message.cc [5:29] + google/protobuf/pyext/message.h [5:29] + google/protobuf/pyext/message_factory.cc [5:29] + google/protobuf/pyext/message_factory.h [5:29] + google/protobuf/pyext/message_module.cc [5:29] + google/protobuf/pyext/repeated_composite_container.cc [5:29] + google/protobuf/pyext/repeated_composite_container.h [5:29] + google/protobuf/pyext/repeated_scalar_container.cc [5:29] + google/protobuf/pyext/repeated_scalar_container.h [5:29] + google/protobuf/pyext/safe_numerics.h [5:29] + google/protobuf/pyext/scoped_pyobject_ptr.h [5:29] + google/protobuf/pyext/unknown_fields.cc [5:29] + google/protobuf/pyext/unknown_fields.h [5:29] + +KEEP BSD-3-Clause 8aaace038fd54f3a52b1f041f9504709 +BELONGS ya.make + Note: matched license text is too long. Read it in the source files. + Scancode info: + Original SPDX id: BSD-3-Clause + Score : 100.00 + Match type : TEXT + Links : http://www.opensource.org/licenses/BSD-3-Clause, https://spdx.org/licenses/BSD-3-Clause + Files with this license: + google/protobuf/__init__.py [5:29] + google/protobuf/descriptor.py [5:29] + google/protobuf/descriptor_database.py [5:29] + google/protobuf/descriptor_pool.py [5:29] + google/protobuf/internal/__init__.py [5:29] + google/protobuf/internal/_parameterized.py [7:31] + google/protobuf/internal/api_implementation.py [5:29] + google/protobuf/internal/containers.py [5:29] + google/protobuf/internal/decoder.py [5:29] + google/protobuf/internal/encoder.py [5:29] + google/protobuf/internal/enum_type_wrapper.py [5:29] + google/protobuf/internal/extension_dict.py [5:29] + google/protobuf/internal/message_listener.py [5:29] + google/protobuf/internal/python_message.py [5:29] + google/protobuf/internal/type_checkers.py [5:29] + google/protobuf/internal/well_known_types.py [5:29] + google/protobuf/internal/wire_format.py [5:29] + google/protobuf/json_format.py [5:29] + google/protobuf/message.py [5:29] + google/protobuf/message_factory.py [5:29] + google/protobuf/proto_builder.py [5:29] + google/protobuf/pyext/cpp_message.py [5:29] + google/protobuf/reflection.py [5:29] + google/protobuf/service.py [5:29] + google/protobuf/service_reflection.py [5:29] + google/protobuf/symbol_database.py [5:29] + google/protobuf/text_encoding.py [5:29] + google/protobuf/text_format.py [5:29] diff --git a/contrib/python/protobuf/py3/.yandex_meta/licenses.list.txt b/contrib/python/protobuf/py3/.yandex_meta/licenses.list.txt new file mode 100644 index 0000000000..81364a2ff7 --- /dev/null +++ b/contrib/python/protobuf/py3/.yandex_meta/licenses.list.txt @@ -0,0 +1,62 @@ +====================BSD-3-Clause==================== +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +====================BSD-3-Clause==================== +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +====================COPYRIGHT==================== +# Copyright 2007 Google Inc. All Rights Reserved. + + +====================COPYRIGHT==================== +# Copyright 2008 Google Inc. All rights reserved. diff --git a/contrib/python/protobuf/py3/README.md b/contrib/python/protobuf/py3/README.md new file mode 100644 index 0000000000..cb8b7e9892 --- /dev/null +++ b/contrib/python/protobuf/py3/README.md @@ -0,0 +1,132 @@ +Protocol Buffers - Google's data interchange format +=================================================== + +[![Build status](https://storage.googleapis.com/protobuf-kokoro-results/status-badge/linux-python.png)](https://fusion.corp.google.com/projectanalysis/current/KOKORO/prod:protobuf%2Fgithub%2Fmaster%2Fubuntu%2Fpython%2Fcontinuous) [![Build status](https://storage.googleapis.com/protobuf-kokoro-results/status-badge/linux-python_compatibility.png)](https://fusion.corp.google.com/projectanalysis/current/KOKORO/prod:protobuf%2Fgithub%2Fmaster%2Fubuntu%2Fpython_compatibility%2Fcontinuous) [![Build status](https://storage.googleapis.com/protobuf-kokoro-results/status-badge/linux-python_cpp.png)](https://fusion.corp.google.com/projectanalysis/current/KOKORO/prod:protobuf%2Fgithub%2Fmaster%2Fubuntu%2Fpython_cpp%2Fcontinuous) [![Build status](https://storage.googleapis.com/protobuf-kokoro-results/status-badge/macos-python.png)](https://fusion.corp.google.com/projectanalysis/current/KOKORO/prod:protobuf%2Fgithub%2Fmaster%2Fmacos%2Fpython%2Fcontinuous) [![Build status](https://storage.googleapis.com/protobuf-kokoro-results/status-badge/macos-python_cpp.png)](https://fusion.corp.google.com/projectanalysis/current/KOKORO/prod:protobuf%2Fgithub%2Fmaster%2Fmacos%2Fpython_cpp%2Fcontinuous) [![Compat check PyPI](https://python-compatibility-tools.appspot.com/one_badge_image?package=protobuf)](https://python-compatibility-tools.appspot.com/one_badge_target?package=protobuf) + +Copyright 2008 Google Inc. + +This directory contains the Python Protocol Buffers runtime library. + +Normally, this directory comes as part of the protobuf package, available +from: + + https://developers.google.com/protocol-buffers/ + +The complete package includes the C++ source code, which includes the +Protocol Compiler (protoc). If you downloaded this package from PyPI +or some other Python-specific source, you may have received only the +Python part of the code. In this case, you will need to obtain the +Protocol Compiler from some other source before you can use this +package. + +Development Warning +=================== + +The pure python performance is slow. For better performance please +use python c++ implementation. + +Installation +============ + +1) Make sure you have Python 2.7 or newer. If in doubt, run: + + $ python -V + +2) If you do not have setuptools installed, note that it will be + downloaded and installed automatically as soon as you run `setup.py`. + If you would rather install it manually, you may do so by following + the instructions on [this page](https://packaging.python.org/en/latest/installing.html#setup-for-installing-packages). + +3) Build the C++ code, or install a binary distribution of `protoc`. If + you install a binary distribution, make sure that it is the same + version as this package. If in doubt, run: + + $ protoc --version + +4) Build and run the tests: + + $ python setup.py build + $ python setup.py test + + To build, test, and use the C++ implementation, you must first compile + `libprotobuf.so`: + + $ (cd .. && make) + + On OS X: + + If you are running a Homebrew-provided Python, you must make sure another + version of protobuf is not already installed, as Homebrew's Python will + search `/usr/local/lib` for `libprotobuf.so` before it searches + `../src/.libs`. + + You can either unlink Homebrew's protobuf or install the `libprotobuf` you + built earlier: + + $ brew unlink protobuf + + or + + $ (cd .. && make install) + + On other *nix: + + You must make `libprotobuf.so` dynamically available. You can either + install libprotobuf you built earlier, or set `LD_LIBRARY_PATH`: + + $ export LD_LIBRARY_PATH=../src/.libs + + or + + $ (cd .. && make install) + + To build the C++ implementation run: + + $ python setup.py build --cpp_implementation + + Then run the tests like so: + + $ python setup.py test --cpp_implementation + + If some tests fail, this library may not work correctly on your + system. Continue at your own risk. + + Please note that there is a known problem with some versions of + Python on Cygwin which causes the tests to fail after printing the + error: `sem_init: Resource temporarily unavailable`. This appears + to be a [bug either in Cygwin or in + Python](http://www.cygwin.com/ml/cygwin/2005-07/msg01378.html). + + We do not know if or when it might be fixed. We also do not know + how likely it is that this bug will affect users in practice. + +5) Install: + + $ python setup.py install + + or: + + $ (cd .. && make install) + $ python setup.py install --cpp_implementation + + This step may require superuser privileges. + NOTE: To use C++ implementation, you need to export an environment + variable before running your program. See the "C++ Implementation" + section below for more details. + +Usage +===== + +The complete documentation for Protocol Buffers is available via the +web at: + + https://developers.google.com/protocol-buffers/ + +C++ Implementation +================== + +The C++ implementation for Python messages is built as a Python extension to +improve the overall protobuf Python performance. + +To use the C++ implementation, you need to install the C++ protobuf runtime +library, please see instructions in the parent directory. diff --git a/contrib/python/protobuf/py3/google/__init__.py b/contrib/python/protobuf/py3/google/__init__.py new file mode 100644 index 0000000000..5585614122 --- /dev/null +++ b/contrib/python/protobuf/py3/google/__init__.py @@ -0,0 +1,4 @@ +try: + __import__('pkg_resources').declare_namespace(__name__) +except ImportError: + __path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/contrib/python/protobuf/py3/google/protobuf/__init__.py b/contrib/python/protobuf/py3/google/protobuf/__init__.py new file mode 100644 index 0000000000..496df6adaf --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/__init__.py @@ -0,0 +1,33 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Copyright 2007 Google Inc. All Rights Reserved. + +__version__ = '3.17.3' diff --git a/contrib/python/protobuf/py3/google/protobuf/compiler/__init__.py b/contrib/python/protobuf/py3/google/protobuf/compiler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/compiler/__init__.py diff --git a/contrib/python/protobuf/py3/google/protobuf/descriptor.py b/contrib/python/protobuf/py3/google/protobuf/descriptor.py new file mode 100644 index 0000000000..70fdae16ff --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/descriptor.py @@ -0,0 +1,1183 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Descriptors essentially contain exactly the information found in a .proto +file, in types that make this information accessible in Python. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import threading +import warnings +import six + +from google.protobuf.internal import api_implementation + +_USE_C_DESCRIPTORS = False +if api_implementation.Type() == 'cpp': + # Used by MakeDescriptor in cpp mode + import binascii + import os + from google.protobuf.pyext import _message + _USE_C_DESCRIPTORS = True + + +class Error(Exception): + """Base error for this module.""" + + +class TypeTransformationError(Error): + """Error transforming between python proto type and corresponding C++ type.""" + + +if _USE_C_DESCRIPTORS: + # This metaclass allows to override the behavior of code like + # isinstance(my_descriptor, FieldDescriptor) + # and make it return True when the descriptor is an instance of the extension + # type written in C++. + class DescriptorMetaclass(type): + def __instancecheck__(cls, obj): + if super(DescriptorMetaclass, cls).__instancecheck__(obj): + return True + if isinstance(obj, cls._C_DESCRIPTOR_CLASS): + return True + return False +else: + # The standard metaclass; nothing changes. + DescriptorMetaclass = type + + +class _Lock(object): + """Wrapper class of threading.Lock(), which is allowed by 'with'.""" + + def __new__(cls): + self = object.__new__(cls) + self._lock = threading.Lock() # pylint: disable=protected-access + return self + + def __enter__(self): + self._lock.acquire() + + def __exit__(self, exc_type, exc_value, exc_tb): + self._lock.release() + + +_lock = threading.Lock() + + +def _Deprecated(name): + if _Deprecated.count > 0: + _Deprecated.count -= 1 + warnings.warn( + 'Call to deprecated create function %s(). Note: Create unlinked ' + 'descriptors is going to go away. Please use get/find descriptors from ' + 'generated code or query the descriptor_pool.' + % name, + category=DeprecationWarning, stacklevel=3) + + +# Deprecated warnings will print 100 times at most which should be enough for +# users to notice and do not cause timeout. +_Deprecated.count = 100 + + +_internal_create_key = object() + + +class DescriptorBase(six.with_metaclass(DescriptorMetaclass)): + + """Descriptors base class. + + This class is the base of all descriptor classes. It provides common options + related functionality. + + Attributes: + has_options: True if the descriptor has non-default options. Usually it + is not necessary to read this -- just call GetOptions() which will + happily return the default instance. However, it's sometimes useful + for efficiency, and also useful inside the protobuf implementation to + avoid some bootstrapping issues. + """ + + if _USE_C_DESCRIPTORS: + # The class, or tuple of classes, that are considered as "virtual + # subclasses" of this descriptor class. + _C_DESCRIPTOR_CLASS = () + + def __init__(self, options, serialized_options, options_class_name): + """Initialize the descriptor given its options message and the name of the + class of the options message. The name of the class is required in case + the options message is None and has to be created. + """ + self._options = options + self._options_class_name = options_class_name + self._serialized_options = serialized_options + + # Does this descriptor have non-default options? + self.has_options = (options is not None) or (serialized_options is not None) + + def _SetOptions(self, options, options_class_name): + """Sets the descriptor's options + + This function is used in generated proto2 files to update descriptor + options. It must not be used outside proto2. + """ + self._options = options + self._options_class_name = options_class_name + + # Does this descriptor have non-default options? + self.has_options = options is not None + + def GetOptions(self): + """Retrieves descriptor options. + + This method returns the options set or creates the default options for the + descriptor. + """ + if self._options: + return self._options + + from google.protobuf import descriptor_pb2 + try: + options_class = getattr(descriptor_pb2, + self._options_class_name) + except AttributeError: + raise RuntimeError('Unknown options class name %s!' % + (self._options_class_name)) + + with _lock: + if self._serialized_options is None: + self._options = options_class() + else: + self._options = _ParseOptions(options_class(), + self._serialized_options) + + return self._options + + +class _NestedDescriptorBase(DescriptorBase): + """Common class for descriptors that can be nested.""" + + def __init__(self, options, options_class_name, name, full_name, + file, containing_type, serialized_start=None, + serialized_end=None, serialized_options=None): + """Constructor. + + Args: + options: Protocol message options or None + to use default message options. + options_class_name (str): The class name of the above options. + name (str): Name of this protocol message type. + full_name (str): Fully-qualified name of this protocol message type, + which will include protocol "package" name and the name of any + enclosing types. + file (FileDescriptor): Reference to file info. + containing_type: if provided, this is a nested descriptor, with this + descriptor as parent, otherwise None. + serialized_start: The start index (inclusive) in block in the + file.serialized_pb that describes this descriptor. + serialized_end: The end index (exclusive) in block in the + file.serialized_pb that describes this descriptor. + serialized_options: Protocol message serialized options or None. + """ + super(_NestedDescriptorBase, self).__init__( + options, serialized_options, options_class_name) + + self.name = name + # TODO(falk): Add function to calculate full_name instead of having it in + # memory? + self.full_name = full_name + self.file = file + self.containing_type = containing_type + + self._serialized_start = serialized_start + self._serialized_end = serialized_end + + def CopyToProto(self, proto): + """Copies this to the matching proto in descriptor_pb2. + + Args: + proto: An empty proto instance from descriptor_pb2. + + Raises: + Error: If self couldn't be serialized, due to to few constructor + arguments. + """ + if (self.file is not None and + self._serialized_start is not None and + self._serialized_end is not None): + proto.ParseFromString(self.file.serialized_pb[ + self._serialized_start:self._serialized_end]) + else: + raise Error('Descriptor does not contain serialization.') + + +class Descriptor(_NestedDescriptorBase): + + """Descriptor for a protocol message type. + + Attributes: + name (str): Name of this protocol message type. + full_name (str): Fully-qualified name of this protocol message type, + which will include protocol "package" name and the name of any + enclosing types. + containing_type (Descriptor): Reference to the descriptor of the type + containing us, or None if this is top-level. + fields (list[FieldDescriptor]): Field descriptors for all fields in + this type. + fields_by_number (dict(int, FieldDescriptor)): Same + :class:`FieldDescriptor` objects as in :attr:`fields`, but indexed + by "number" attribute in each FieldDescriptor. + fields_by_name (dict(str, FieldDescriptor)): Same + :class:`FieldDescriptor` objects as in :attr:`fields`, but indexed by + "name" attribute in each :class:`FieldDescriptor`. + nested_types (list[Descriptor]): Descriptor references + for all protocol message types nested within this one. + nested_types_by_name (dict(str, Descriptor)): Same Descriptor + objects as in :attr:`nested_types`, but indexed by "name" attribute + in each Descriptor. + enum_types (list[EnumDescriptor]): :class:`EnumDescriptor` references + for all enums contained within this type. + enum_types_by_name (dict(str, EnumDescriptor)): Same + :class:`EnumDescriptor` objects as in :attr:`enum_types`, but + indexed by "name" attribute in each EnumDescriptor. + enum_values_by_name (dict(str, EnumValueDescriptor)): Dict mapping + from enum value name to :class:`EnumValueDescriptor` for that value. + extensions (list[FieldDescriptor]): All extensions defined directly + within this message type (NOT within a nested type). + extensions_by_name (dict(str, FieldDescriptor)): Same FieldDescriptor + objects as :attr:`extensions`, but indexed by "name" attribute of each + FieldDescriptor. + is_extendable (bool): Does this type define any extension ranges? + oneofs (list[OneofDescriptor]): The list of descriptors for oneof fields + in this message. + oneofs_by_name (dict(str, OneofDescriptor)): Same objects as in + :attr:`oneofs`, but indexed by "name" attribute. + file (FileDescriptor): Reference to file descriptor. + + """ + + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.Descriptor + + def __new__( + cls, + name=None, + full_name=None, + filename=None, + containing_type=None, + fields=None, + nested_types=None, + enum_types=None, + extensions=None, + options=None, + serialized_options=None, + is_extendable=True, + extension_ranges=None, + oneofs=None, + file=None, # pylint: disable=redefined-builtin + serialized_start=None, + serialized_end=None, + syntax=None, + create_key=None): + _message.Message._CheckCalledFromGeneratedFile() + return _message.default_pool.FindMessageTypeByName(full_name) + + # NOTE(tmarek): The file argument redefining a builtin is nothing we can + # fix right now since we don't know how many clients already rely on the + # name of the argument. + def __init__(self, name, full_name, filename, containing_type, fields, + nested_types, enum_types, extensions, options=None, + serialized_options=None, + is_extendable=True, extension_ranges=None, oneofs=None, + file=None, serialized_start=None, serialized_end=None, # pylint: disable=redefined-builtin + syntax=None, create_key=None): + """Arguments to __init__() are as described in the description + of Descriptor fields above. + + Note that filename is an obsolete argument, that is not used anymore. + Please use file.name to access this as an attribute. + """ + if create_key is not _internal_create_key: + _Deprecated('Descriptor') + + super(Descriptor, self).__init__( + options, 'MessageOptions', name, full_name, file, + containing_type, serialized_start=serialized_start, + serialized_end=serialized_end, serialized_options=serialized_options) + + # We have fields in addition to fields_by_name and fields_by_number, + # so that: + # 1. Clients can index fields by "order in which they're listed." + # 2. Clients can easily iterate over all fields with the terse + # syntax: for f in descriptor.fields: ... + self.fields = fields + for field in self.fields: + field.containing_type = self + self.fields_by_number = dict((f.number, f) for f in fields) + self.fields_by_name = dict((f.name, f) for f in fields) + self._fields_by_camelcase_name = None + + self.nested_types = nested_types + for nested_type in nested_types: + nested_type.containing_type = self + self.nested_types_by_name = dict((t.name, t) for t in nested_types) + + self.enum_types = enum_types + for enum_type in self.enum_types: + enum_type.containing_type = self + self.enum_types_by_name = dict((t.name, t) for t in enum_types) + self.enum_values_by_name = dict( + (v.name, v) for t in enum_types for v in t.values) + + self.extensions = extensions + for extension in self.extensions: + extension.extension_scope = self + self.extensions_by_name = dict((f.name, f) for f in extensions) + self.is_extendable = is_extendable + self.extension_ranges = extension_ranges + self.oneofs = oneofs if oneofs is not None else [] + self.oneofs_by_name = dict((o.name, o) for o in self.oneofs) + for oneof in self.oneofs: + oneof.containing_type = self + self.syntax = syntax or "proto2" + + @property + def fields_by_camelcase_name(self): + """Same FieldDescriptor objects as in :attr:`fields`, but indexed by + :attr:`FieldDescriptor.camelcase_name`. + """ + if self._fields_by_camelcase_name is None: + self._fields_by_camelcase_name = dict( + (f.camelcase_name, f) for f in self.fields) + return self._fields_by_camelcase_name + + def EnumValueName(self, enum, value): + """Returns the string name of an enum value. + + This is just a small helper method to simplify a common operation. + + Args: + enum: string name of the Enum. + value: int, value of the enum. + + Returns: + string name of the enum value. + + Raises: + KeyError if either the Enum doesn't exist or the value is not a valid + value for the enum. + """ + return self.enum_types_by_name[enum].values_by_number[value].name + + def CopyToProto(self, proto): + """Copies this to a descriptor_pb2.DescriptorProto. + + Args: + proto: An empty descriptor_pb2.DescriptorProto. + """ + # This function is overridden to give a better doc comment. + super(Descriptor, self).CopyToProto(proto) + + +# TODO(robinson): We should have aggressive checking here, +# for example: +# * If you specify a repeated field, you should not be allowed +# to specify a default value. +# * [Other examples here as needed]. +# +# TODO(robinson): for this and other *Descriptor classes, we +# might also want to lock things down aggressively (e.g., +# prevent clients from setting the attributes). Having +# stronger invariants here in general will reduce the number +# of runtime checks we must do in reflection.py... +class FieldDescriptor(DescriptorBase): + + """Descriptor for a single field in a .proto file. + + Attributes: + name (str): Name of this field, exactly as it appears in .proto. + full_name (str): Name of this field, including containing scope. This is + particularly relevant for extensions. + index (int): Dense, 0-indexed index giving the order that this + field textually appears within its message in the .proto file. + number (int): Tag number declared for this field in the .proto file. + + type (int): (One of the TYPE_* constants below) Declared type. + cpp_type (int): (One of the CPPTYPE_* constants below) C++ type used to + represent this field. + + label (int): (One of the LABEL_* constants below) Tells whether this + field is optional, required, or repeated. + has_default_value (bool): True if this field has a default value defined, + otherwise false. + default_value (Varies): Default value of this field. Only + meaningful for non-repeated scalar fields. Repeated fields + should always set this to [], and non-repeated composite + fields should always set this to None. + + containing_type (Descriptor): Descriptor of the protocol message + type that contains this field. Set by the Descriptor constructor + if we're passed into one. + Somewhat confusingly, for extension fields, this is the + descriptor of the EXTENDED message, not the descriptor + of the message containing this field. (See is_extension and + extension_scope below). + message_type (Descriptor): If a composite field, a descriptor + of the message type contained in this field. Otherwise, this is None. + enum_type (EnumDescriptor): If this field contains an enum, a + descriptor of that enum. Otherwise, this is None. + + is_extension: True iff this describes an extension field. + extension_scope (Descriptor): Only meaningful if is_extension is True. + Gives the message that immediately contains this extension field. + Will be None iff we're a top-level (file-level) extension field. + + options (descriptor_pb2.FieldOptions): Protocol message field options or + None to use default field options. + + containing_oneof (OneofDescriptor): If the field is a member of a oneof + union, contains its descriptor. Otherwise, None. + + file (FileDescriptor): Reference to file descriptor. + """ + + # Must be consistent with C++ FieldDescriptor::Type enum in + # descriptor.h. + # + # TODO(robinson): Find a way to eliminate this repetition. + TYPE_DOUBLE = 1 + TYPE_FLOAT = 2 + TYPE_INT64 = 3 + TYPE_UINT64 = 4 + TYPE_INT32 = 5 + TYPE_FIXED64 = 6 + TYPE_FIXED32 = 7 + TYPE_BOOL = 8 + TYPE_STRING = 9 + TYPE_GROUP = 10 + TYPE_MESSAGE = 11 + TYPE_BYTES = 12 + TYPE_UINT32 = 13 + TYPE_ENUM = 14 + TYPE_SFIXED32 = 15 + TYPE_SFIXED64 = 16 + TYPE_SINT32 = 17 + TYPE_SINT64 = 18 + MAX_TYPE = 18 + + # Must be consistent with C++ FieldDescriptor::CppType enum in + # descriptor.h. + # + # TODO(robinson): Find a way to eliminate this repetition. + CPPTYPE_INT32 = 1 + CPPTYPE_INT64 = 2 + CPPTYPE_UINT32 = 3 + CPPTYPE_UINT64 = 4 + CPPTYPE_DOUBLE = 5 + CPPTYPE_FLOAT = 6 + CPPTYPE_BOOL = 7 + CPPTYPE_ENUM = 8 + CPPTYPE_STRING = 9 + CPPTYPE_MESSAGE = 10 + MAX_CPPTYPE = 10 + + _PYTHON_TO_CPP_PROTO_TYPE_MAP = { + TYPE_DOUBLE: CPPTYPE_DOUBLE, + TYPE_FLOAT: CPPTYPE_FLOAT, + TYPE_ENUM: CPPTYPE_ENUM, + TYPE_INT64: CPPTYPE_INT64, + TYPE_SINT64: CPPTYPE_INT64, + TYPE_SFIXED64: CPPTYPE_INT64, + TYPE_UINT64: CPPTYPE_UINT64, + TYPE_FIXED64: CPPTYPE_UINT64, + TYPE_INT32: CPPTYPE_INT32, + TYPE_SFIXED32: CPPTYPE_INT32, + TYPE_SINT32: CPPTYPE_INT32, + TYPE_UINT32: CPPTYPE_UINT32, + TYPE_FIXED32: CPPTYPE_UINT32, + TYPE_BYTES: CPPTYPE_STRING, + TYPE_STRING: CPPTYPE_STRING, + TYPE_BOOL: CPPTYPE_BOOL, + TYPE_MESSAGE: CPPTYPE_MESSAGE, + TYPE_GROUP: CPPTYPE_MESSAGE + } + + # Must be consistent with C++ FieldDescriptor::Label enum in + # descriptor.h. + # + # TODO(robinson): Find a way to eliminate this repetition. + LABEL_OPTIONAL = 1 + LABEL_REQUIRED = 2 + LABEL_REPEATED = 3 + MAX_LABEL = 3 + + # Must be consistent with C++ constants kMaxNumber, kFirstReservedNumber, + # and kLastReservedNumber in descriptor.h + MAX_FIELD_NUMBER = (1 << 29) - 1 + FIRST_RESERVED_FIELD_NUMBER = 19000 + LAST_RESERVED_FIELD_NUMBER = 19999 + + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.FieldDescriptor + + def __new__(cls, name, full_name, index, number, type, cpp_type, label, + default_value, message_type, enum_type, containing_type, + is_extension, extension_scope, options=None, + serialized_options=None, + has_default_value=True, containing_oneof=None, json_name=None, + file=None, create_key=None): # pylint: disable=redefined-builtin + _message.Message._CheckCalledFromGeneratedFile() + if is_extension: + return _message.default_pool.FindExtensionByName(full_name) + else: + return _message.default_pool.FindFieldByName(full_name) + + def __init__(self, name, full_name, index, number, type, cpp_type, label, + default_value, message_type, enum_type, containing_type, + is_extension, extension_scope, options=None, + serialized_options=None, + has_default_value=True, containing_oneof=None, json_name=None, + file=None, create_key=None): # pylint: disable=redefined-builtin + """The arguments are as described in the description of FieldDescriptor + attributes above. + + Note that containing_type may be None, and may be set later if necessary + (to deal with circular references between message types, for example). + Likewise for extension_scope. + """ + if create_key is not _internal_create_key: + _Deprecated('FieldDescriptor') + + super(FieldDescriptor, self).__init__( + options, serialized_options, 'FieldOptions') + self.name = name + self.full_name = full_name + self.file = file + self._camelcase_name = None + if json_name is None: + self.json_name = _ToJsonName(name) + else: + self.json_name = json_name + self.index = index + self.number = number + self.type = type + self.cpp_type = cpp_type + self.label = label + self.has_default_value = has_default_value + self.default_value = default_value + self.containing_type = containing_type + self.message_type = message_type + self.enum_type = enum_type + self.is_extension = is_extension + self.extension_scope = extension_scope + self.containing_oneof = containing_oneof + if api_implementation.Type() == 'cpp': + if is_extension: + self._cdescriptor = _message.default_pool.FindExtensionByName(full_name) + else: + self._cdescriptor = _message.default_pool.FindFieldByName(full_name) + else: + self._cdescriptor = None + + @property + def camelcase_name(self): + """Camelcase name of this field. + + Returns: + str: the name in CamelCase. + """ + if self._camelcase_name is None: + self._camelcase_name = _ToCamelCase(self.name) + return self._camelcase_name + + @staticmethod + def ProtoTypeToCppProtoType(proto_type): + """Converts from a Python proto type to a C++ Proto Type. + + The Python ProtocolBuffer classes specify both the 'Python' datatype and the + 'C++' datatype - and they're not the same. This helper method should + translate from one to another. + + Args: + proto_type: the Python proto type (descriptor.FieldDescriptor.TYPE_*) + Returns: + int: descriptor.FieldDescriptor.CPPTYPE_*, the C++ type. + Raises: + TypeTransformationError: when the Python proto type isn't known. + """ + try: + return FieldDescriptor._PYTHON_TO_CPP_PROTO_TYPE_MAP[proto_type] + except KeyError: + raise TypeTransformationError('Unknown proto_type: %s' % proto_type) + + +class EnumDescriptor(_NestedDescriptorBase): + + """Descriptor for an enum defined in a .proto file. + + Attributes: + name (str): Name of the enum type. + full_name (str): Full name of the type, including package name + and any enclosing type(s). + + values (list[EnumValueDescriptors]): List of the values + in this enum. + values_by_name (dict(str, EnumValueDescriptor)): Same as :attr:`values`, + but indexed by the "name" field of each EnumValueDescriptor. + values_by_number (dict(int, EnumValueDescriptor)): Same as :attr:`values`, + but indexed by the "number" field of each EnumValueDescriptor. + containing_type (Descriptor): Descriptor of the immediate containing + type of this enum, or None if this is an enum defined at the + top level in a .proto file. Set by Descriptor's constructor + if we're passed into one. + file (FileDescriptor): Reference to file descriptor. + options (descriptor_pb2.EnumOptions): Enum options message or + None to use default enum options. + """ + + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.EnumDescriptor + + def __new__(cls, name, full_name, filename, values, + containing_type=None, options=None, + serialized_options=None, file=None, # pylint: disable=redefined-builtin + serialized_start=None, serialized_end=None, create_key=None): + _message.Message._CheckCalledFromGeneratedFile() + return _message.default_pool.FindEnumTypeByName(full_name) + + def __init__(self, name, full_name, filename, values, + containing_type=None, options=None, + serialized_options=None, file=None, # pylint: disable=redefined-builtin + serialized_start=None, serialized_end=None, create_key=None): + """Arguments are as described in the attribute description above. + + Note that filename is an obsolete argument, that is not used anymore. + Please use file.name to access this as an attribute. + """ + if create_key is not _internal_create_key: + _Deprecated('EnumDescriptor') + + super(EnumDescriptor, self).__init__( + options, 'EnumOptions', name, full_name, file, + containing_type, serialized_start=serialized_start, + serialized_end=serialized_end, serialized_options=serialized_options) + + self.values = values + for value in self.values: + value.type = self + self.values_by_name = dict((v.name, v) for v in values) + # Values are reversed to ensure that the first alias is retained. + self.values_by_number = dict((v.number, v) for v in reversed(values)) + + def CopyToProto(self, proto): + """Copies this to a descriptor_pb2.EnumDescriptorProto. + + Args: + proto (descriptor_pb2.EnumDescriptorProto): An empty descriptor proto. + """ + # This function is overridden to give a better doc comment. + super(EnumDescriptor, self).CopyToProto(proto) + + +class EnumValueDescriptor(DescriptorBase): + + """Descriptor for a single value within an enum. + + Attributes: + name (str): Name of this value. + index (int): Dense, 0-indexed index giving the order that this + value appears textually within its enum in the .proto file. + number (int): Actual number assigned to this enum value. + type (EnumDescriptor): :class:`EnumDescriptor` to which this value + belongs. Set by :class:`EnumDescriptor`'s constructor if we're + passed into one. + options (descriptor_pb2.EnumValueOptions): Enum value options message or + None to use default enum value options options. + """ + + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.EnumValueDescriptor + + def __new__(cls, name, index, number, + type=None, # pylint: disable=redefined-builtin + options=None, serialized_options=None, create_key=None): + _message.Message._CheckCalledFromGeneratedFile() + # There is no way we can build a complete EnumValueDescriptor with the + # given parameters (the name of the Enum is not known, for example). + # Fortunately generated files just pass it to the EnumDescriptor() + # constructor, which will ignore it, so returning None is good enough. + return None + + def __init__(self, name, index, number, + type=None, # pylint: disable=redefined-builtin + options=None, serialized_options=None, create_key=None): + """Arguments are as described in the attribute description above.""" + if create_key is not _internal_create_key: + _Deprecated('EnumValueDescriptor') + + super(EnumValueDescriptor, self).__init__( + options, serialized_options, 'EnumValueOptions') + self.name = name + self.index = index + self.number = number + self.type = type + + +class OneofDescriptor(DescriptorBase): + """Descriptor for a oneof field. + + Attributes: + name (str): Name of the oneof field. + full_name (str): Full name of the oneof field, including package name. + index (int): 0-based index giving the order of the oneof field inside + its containing type. + containing_type (Descriptor): :class:`Descriptor` of the protocol message + type that contains this field. Set by the :class:`Descriptor` constructor + if we're passed into one. + fields (list[FieldDescriptor]): The list of field descriptors this + oneof can contain. + """ + + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.OneofDescriptor + + def __new__( + cls, name, full_name, index, containing_type, fields, options=None, + serialized_options=None, create_key=None): + _message.Message._CheckCalledFromGeneratedFile() + return _message.default_pool.FindOneofByName(full_name) + + def __init__( + self, name, full_name, index, containing_type, fields, options=None, + serialized_options=None, create_key=None): + """Arguments are as described in the attribute description above.""" + if create_key is not _internal_create_key: + _Deprecated('OneofDescriptor') + + super(OneofDescriptor, self).__init__( + options, serialized_options, 'OneofOptions') + self.name = name + self.full_name = full_name + self.index = index + self.containing_type = containing_type + self.fields = fields + + +class ServiceDescriptor(_NestedDescriptorBase): + + """Descriptor for a service. + + Attributes: + name (str): Name of the service. + full_name (str): Full name of the service, including package name. + index (int): 0-indexed index giving the order that this services + definition appears within the .proto file. + methods (list[MethodDescriptor]): List of methods provided by this + service. + methods_by_name (dict(str, MethodDescriptor)): Same + :class:`MethodDescriptor` objects as in :attr:`methods_by_name`, but + indexed by "name" attribute in each :class:`MethodDescriptor`. + options (descriptor_pb2.ServiceOptions): Service options message or + None to use default service options. + file (FileDescriptor): Reference to file info. + """ + + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.ServiceDescriptor + + def __new__( + cls, + name=None, + full_name=None, + index=None, + methods=None, + options=None, + serialized_options=None, + file=None, # pylint: disable=redefined-builtin + serialized_start=None, + serialized_end=None, + create_key=None): + _message.Message._CheckCalledFromGeneratedFile() # pylint: disable=protected-access + return _message.default_pool.FindServiceByName(full_name) + + def __init__(self, name, full_name, index, methods, options=None, + serialized_options=None, file=None, # pylint: disable=redefined-builtin + serialized_start=None, serialized_end=None, create_key=None): + if create_key is not _internal_create_key: + _Deprecated('ServiceDescriptor') + + super(ServiceDescriptor, self).__init__( + options, 'ServiceOptions', name, full_name, file, + None, serialized_start=serialized_start, + serialized_end=serialized_end, serialized_options=serialized_options) + self.index = index + self.methods = methods + self.methods_by_name = dict((m.name, m) for m in methods) + # Set the containing service for each method in this service. + for method in self.methods: + method.containing_service = self + + def FindMethodByName(self, name): + """Searches for the specified method, and returns its descriptor. + + Args: + name (str): Name of the method. + Returns: + MethodDescriptor or None: the descriptor for the requested method, if + found. + """ + return self.methods_by_name.get(name, None) + + def CopyToProto(self, proto): + """Copies this to a descriptor_pb2.ServiceDescriptorProto. + + Args: + proto (descriptor_pb2.ServiceDescriptorProto): An empty descriptor proto. + """ + # This function is overridden to give a better doc comment. + super(ServiceDescriptor, self).CopyToProto(proto) + + +class MethodDescriptor(DescriptorBase): + + """Descriptor for a method in a service. + + Attributes: + name (str): Name of the method within the service. + full_name (str): Full name of method. + index (int): 0-indexed index of the method inside the service. + containing_service (ServiceDescriptor): The service that contains this + method. + input_type (Descriptor): The descriptor of the message that this method + accepts. + output_type (Descriptor): The descriptor of the message that this method + returns. + options (descriptor_pb2.MethodOptions or None): Method options message, or + None to use default method options. + """ + + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.MethodDescriptor + + def __new__(cls, name, full_name, index, containing_service, + input_type, output_type, options=None, serialized_options=None, + create_key=None): + _message.Message._CheckCalledFromGeneratedFile() # pylint: disable=protected-access + return _message.default_pool.FindMethodByName(full_name) + + def __init__(self, name, full_name, index, containing_service, + input_type, output_type, options=None, serialized_options=None, + create_key=None): + """The arguments are as described in the description of MethodDescriptor + attributes above. + + Note that containing_service may be None, and may be set later if necessary. + """ + if create_key is not _internal_create_key: + _Deprecated('MethodDescriptor') + + super(MethodDescriptor, self).__init__( + options, serialized_options, 'MethodOptions') + self.name = name + self.full_name = full_name + self.index = index + self.containing_service = containing_service + self.input_type = input_type + self.output_type = output_type + + def CopyToProto(self, proto): + """Copies this to a descriptor_pb2.MethodDescriptorProto. + + Args: + proto (descriptor_pb2.MethodDescriptorProto): An empty descriptor proto. + + Raises: + Error: If self couldn't be serialized, due to too few constructor + arguments. + """ + if self.containing_service is not None: + from google.protobuf import descriptor_pb2 + service_proto = descriptor_pb2.ServiceDescriptorProto() + self.containing_service.CopyToProto(service_proto) + proto.CopyFrom(service_proto.method[self.index]) + else: + raise Error('Descriptor does not contain a service.') + + +class FileDescriptor(DescriptorBase): + """Descriptor for a file. Mimics the descriptor_pb2.FileDescriptorProto. + + Note that :attr:`enum_types_by_name`, :attr:`extensions_by_name`, and + :attr:`dependencies` fields are only set by the + :py:mod:`google.protobuf.message_factory` module, and not by the generated + proto code. + + Attributes: + name (str): Name of file, relative to root of source tree. + package (str): Name of the package + syntax (str): string indicating syntax of the file (can be "proto2" or + "proto3") + serialized_pb (bytes): Byte string of serialized + :class:`descriptor_pb2.FileDescriptorProto`. + dependencies (list[FileDescriptor]): List of other :class:`FileDescriptor` + objects this :class:`FileDescriptor` depends on. + public_dependencies (list[FileDescriptor]): A subset of + :attr:`dependencies`, which were declared as "public". + message_types_by_name (dict(str, Descriptor)): Mapping from message names + to their :class:`Desctiptor`. + enum_types_by_name (dict(str, EnumDescriptor)): Mapping from enum names to + their :class:`EnumDescriptor`. + extensions_by_name (dict(str, FieldDescriptor)): Mapping from extension + names declared at file scope to their :class:`FieldDescriptor`. + services_by_name (dict(str, ServiceDescriptor)): Mapping from services' + names to their :class:`ServiceDescriptor`. + pool (DescriptorPool): The pool this descriptor belongs to. When not + passed to the constructor, the global default pool is used. + """ + + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.FileDescriptor + + def __new__(cls, name, package, options=None, + serialized_options=None, serialized_pb=None, + dependencies=None, public_dependencies=None, + syntax=None, pool=None, create_key=None): + # FileDescriptor() is called from various places, not only from generated + # files, to register dynamic proto files and messages. + # pylint: disable=g-explicit-bool-comparison + if serialized_pb == b'': + # Cpp generated code must be linked in if serialized_pb is '' + try: + return _message.default_pool.FindFileByName(name) + except KeyError: + raise RuntimeError('Please link in cpp generated lib for %s' % (name)) + elif serialized_pb: + return _message.default_pool.AddSerializedFile(serialized_pb) + else: + return super(FileDescriptor, cls).__new__(cls) + + def __init__(self, name, package, options=None, + serialized_options=None, serialized_pb=None, + dependencies=None, public_dependencies=None, + syntax=None, pool=None, create_key=None): + """Constructor.""" + if create_key is not _internal_create_key: + _Deprecated('FileDescriptor') + + super(FileDescriptor, self).__init__( + options, serialized_options, 'FileOptions') + + if pool is None: + from google.protobuf import descriptor_pool + pool = descriptor_pool.Default() + self.pool = pool + self.message_types_by_name = {} + self.name = name + self.package = package + self.syntax = syntax or "proto2" + self.serialized_pb = serialized_pb + + self.enum_types_by_name = {} + self.extensions_by_name = {} + self.services_by_name = {} + self.dependencies = (dependencies or []) + self.public_dependencies = (public_dependencies or []) + + def CopyToProto(self, proto): + """Copies this to a descriptor_pb2.FileDescriptorProto. + + Args: + proto: An empty descriptor_pb2.FileDescriptorProto. + """ + proto.ParseFromString(self.serialized_pb) + + +def _ParseOptions(message, string): + """Parses serialized options. + + This helper function is used to parse serialized options in generated + proto2 files. It must not be used outside proto2. + """ + message.ParseFromString(string) + return message + + +def _ToCamelCase(name): + """Converts name to camel-case and returns it.""" + capitalize_next = False + result = [] + + for c in name: + if c == '_': + if result: + capitalize_next = True + elif capitalize_next: + result.append(c.upper()) + capitalize_next = False + else: + result += c + + # Lower-case the first letter. + if result and result[0].isupper(): + result[0] = result[0].lower() + return ''.join(result) + + +def _OptionsOrNone(descriptor_proto): + """Returns the value of the field `options`, or None if it is not set.""" + if descriptor_proto.HasField('options'): + return descriptor_proto.options + else: + return None + + +def _ToJsonName(name): + """Converts name to Json name and returns it.""" + capitalize_next = False + result = [] + + for c in name: + if c == '_': + capitalize_next = True + elif capitalize_next: + result.append(c.upper()) + capitalize_next = False + else: + result += c + + return ''.join(result) + + +def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True, + syntax=None): + """Make a protobuf Descriptor given a DescriptorProto protobuf. + + Handles nested descriptors. Note that this is limited to the scope of defining + a message inside of another message. Composite fields can currently only be + resolved if the message is defined in the same scope as the field. + + Args: + desc_proto: The descriptor_pb2.DescriptorProto protobuf message. + package: Optional package name for the new message Descriptor (string). + build_file_if_cpp: Update the C++ descriptor pool if api matches. + Set to False on recursion, so no duplicates are created. + syntax: The syntax/semantics that should be used. Set to "proto3" to get + proto3 field presence semantics. + Returns: + A Descriptor for protobuf messages. + """ + if api_implementation.Type() == 'cpp' and build_file_if_cpp: + # The C++ implementation requires all descriptors to be backed by the same + # definition in the C++ descriptor pool. To do this, we build a + # FileDescriptorProto with the same definition as this descriptor and build + # it into the pool. + from google.protobuf import descriptor_pb2 + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.message_type.add().MergeFrom(desc_proto) + + # Generate a random name for this proto file to prevent conflicts with any + # imported ones. We need to specify a file name so the descriptor pool + # accepts our FileDescriptorProto, but it is not important what that file + # name is actually set to. + proto_name = binascii.hexlify(os.urandom(16)).decode('ascii') + + if package: + file_descriptor_proto.name = os.path.join(package.replace('.', '/'), + proto_name + '.proto') + file_descriptor_proto.package = package + else: + file_descriptor_proto.name = proto_name + '.proto' + + _message.default_pool.Add(file_descriptor_proto) + result = _message.default_pool.FindFileByName(file_descriptor_proto.name) + + if _USE_C_DESCRIPTORS: + return result.message_types_by_name[desc_proto.name] + + full_message_name = [desc_proto.name] + if package: full_message_name.insert(0, package) + + # Create Descriptors for enum types + enum_types = {} + for enum_proto in desc_proto.enum_type: + full_name = '.'.join(full_message_name + [enum_proto.name]) + enum_desc = EnumDescriptor( + enum_proto.name, full_name, None, [ + EnumValueDescriptor(enum_val.name, ii, enum_val.number, + create_key=_internal_create_key) + for ii, enum_val in enumerate(enum_proto.value)], + create_key=_internal_create_key) + enum_types[full_name] = enum_desc + + # Create Descriptors for nested types + nested_types = {} + for nested_proto in desc_proto.nested_type: + full_name = '.'.join(full_message_name + [nested_proto.name]) + # Nested types are just those defined inside of the message, not all types + # used by fields in the message, so no loops are possible here. + nested_desc = MakeDescriptor(nested_proto, + package='.'.join(full_message_name), + build_file_if_cpp=False, + syntax=syntax) + nested_types[full_name] = nested_desc + + fields = [] + for field_proto in desc_proto.field: + full_name = '.'.join(full_message_name + [field_proto.name]) + enum_desc = None + nested_desc = None + if field_proto.json_name: + json_name = field_proto.json_name + else: + json_name = None + if field_proto.HasField('type_name'): + type_name = field_proto.type_name + full_type_name = '.'.join(full_message_name + + [type_name[type_name.rfind('.')+1:]]) + if full_type_name in nested_types: + nested_desc = nested_types[full_type_name] + elif full_type_name in enum_types: + enum_desc = enum_types[full_type_name] + # Else type_name references a non-local type, which isn't implemented + field = FieldDescriptor( + field_proto.name, full_name, field_proto.number - 1, + field_proto.number, field_proto.type, + FieldDescriptor.ProtoTypeToCppProtoType(field_proto.type), + field_proto.label, None, nested_desc, enum_desc, None, False, None, + options=_OptionsOrNone(field_proto), has_default_value=False, + json_name=json_name, create_key=_internal_create_key) + fields.append(field) + + desc_name = '.'.join(full_message_name) + return Descriptor(desc_proto.name, desc_name, None, None, fields, + list(nested_types.values()), list(enum_types.values()), [], + options=_OptionsOrNone(desc_proto), + create_key=_internal_create_key) diff --git a/contrib/python/protobuf/py3/google/protobuf/descriptor_database.py b/contrib/python/protobuf/py3/google/protobuf/descriptor_database.py new file mode 100644 index 0000000000..073eddc711 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/descriptor_database.py @@ -0,0 +1,177 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Provides a container for DescriptorProtos.""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +import warnings + + +class Error(Exception): + pass + + +class DescriptorDatabaseConflictingDefinitionError(Error): + """Raised when a proto is added with the same name & different descriptor.""" + + +class DescriptorDatabase(object): + """A container accepting FileDescriptorProtos and maps DescriptorProtos.""" + + def __init__(self): + self._file_desc_protos_by_file = {} + self._file_desc_protos_by_symbol = {} + + def Add(self, file_desc_proto): + """Adds the FileDescriptorProto and its types to this database. + + Args: + file_desc_proto: The FileDescriptorProto to add. + Raises: + DescriptorDatabaseConflictingDefinitionError: if an attempt is made to + add a proto with the same name but different definition than an + existing proto in the database. + """ + proto_name = file_desc_proto.name + if proto_name not in self._file_desc_protos_by_file: + self._file_desc_protos_by_file[proto_name] = file_desc_proto + elif self._file_desc_protos_by_file[proto_name] != file_desc_proto: + raise DescriptorDatabaseConflictingDefinitionError( + '%s already added, but with different descriptor.' % proto_name) + else: + return + + # Add all the top-level descriptors to the index. + package = file_desc_proto.package + for message in file_desc_proto.message_type: + for name in _ExtractSymbols(message, package): + self._AddSymbol(name, file_desc_proto) + for enum in file_desc_proto.enum_type: + self._AddSymbol(('.'.join((package, enum.name))), file_desc_proto) + for enum_value in enum.value: + self._file_desc_protos_by_symbol[ + '.'.join((package, enum_value.name))] = file_desc_proto + for extension in file_desc_proto.extension: + self._AddSymbol(('.'.join((package, extension.name))), file_desc_proto) + for service in file_desc_proto.service: + self._AddSymbol(('.'.join((package, service.name))), file_desc_proto) + + def FindFileByName(self, name): + """Finds the file descriptor proto by file name. + + Typically the file name is a relative path ending to a .proto file. The + proto with the given name will have to have been added to this database + using the Add method or else an error will be raised. + + Args: + name: The file name to find. + + Returns: + The file descriptor proto matching the name. + + Raises: + KeyError if no file by the given name was added. + """ + + return self._file_desc_protos_by_file[name] + + def FindFileContainingSymbol(self, symbol): + """Finds the file descriptor proto containing the specified symbol. + + The symbol should be a fully qualified name including the file descriptor's + package and any containing messages. Some examples: + + 'some.package.name.Message' + 'some.package.name.Message.NestedEnum' + 'some.package.name.Message.some_field' + + The file descriptor proto containing the specified symbol must be added to + this database using the Add method or else an error will be raised. + + Args: + symbol: The fully qualified symbol name. + + Returns: + The file descriptor proto containing the symbol. + + Raises: + KeyError if no file contains the specified symbol. + """ + try: + return self._file_desc_protos_by_symbol[symbol] + except KeyError: + # Fields, enum values, and nested extensions are not in + # _file_desc_protos_by_symbol. Try to find the top level + # descriptor. Non-existent nested symbol under a valid top level + # descriptor can also be found. The behavior is the same with + # protobuf C++. + top_level, _, _ = symbol.rpartition('.') + try: + return self._file_desc_protos_by_symbol[top_level] + except KeyError: + # Raise the original symbol as a KeyError for better diagnostics. + raise KeyError(symbol) + + def FindFileContainingExtension(self, extendee_name, extension_number): + # TODO(jieluo): implement this API. + return None + + def FindAllExtensionNumbers(self, extendee_name): + # TODO(jieluo): implement this API. + return [] + + def _AddSymbol(self, name, file_desc_proto): + if name in self._file_desc_protos_by_symbol: + warn_msg = ('Conflict register for file "' + file_desc_proto.name + + '": ' + name + + ' is already defined in file "' + + self._file_desc_protos_by_symbol[name].name + '"') + warnings.warn(warn_msg, RuntimeWarning) + self._file_desc_protos_by_symbol[name] = file_desc_proto + + +def _ExtractSymbols(desc_proto, package): + """Pulls out all the symbols from a descriptor proto. + + Args: + desc_proto: The proto to extract symbols from. + package: The package containing the descriptor type. + + Yields: + The fully qualified name found in the descriptor. + """ + message_name = package + '.' + desc_proto.name if package else desc_proto.name + yield message_name + for nested_type in desc_proto.nested_type: + for symbol in _ExtractSymbols(nested_type, message_name): + yield symbol + for enum_type in desc_proto.enum_type: + yield '.'.join((message_name, enum_type.name)) diff --git a/contrib/python/protobuf/py3/google/protobuf/descriptor_pool.py b/contrib/python/protobuf/py3/google/protobuf/descriptor_pool.py new file mode 100644 index 0000000000..de9100b09c --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/descriptor_pool.py @@ -0,0 +1,1271 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Provides DescriptorPool to use as a container for proto2 descriptors. + +The DescriptorPool is used in conjection with a DescriptorDatabase to maintain +a collection of protocol buffer descriptors for use when dynamically creating +message types at runtime. + +For most applications protocol buffers should be used via modules generated by +the protocol buffer compiler tool. This should only be used when the type of +protocol buffers used in an application or library cannot be predetermined. + +Below is a straightforward example on how to use this class:: + + pool = DescriptorPool() + file_descriptor_protos = [ ... ] + for file_descriptor_proto in file_descriptor_protos: + pool.Add(file_descriptor_proto) + my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType') + +The message descriptor can be used in conjunction with the message_factory +module in order to create a protocol buffer class that can be encoded and +decoded. + +If you want to get a Python class for the specified proto, use the +helper functions inside google.protobuf.message_factory +directly instead of this class. +""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +import collections +import warnings + +from google.protobuf import descriptor +from google.protobuf import descriptor_database +from google.protobuf import text_encoding + + +_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS # pylint: disable=protected-access + + +def _Deprecated(func): + """Mark functions as deprecated.""" + + def NewFunc(*args, **kwargs): + warnings.warn( + 'Call to deprecated function %s(). Note: Do add unlinked descriptors ' + 'to descriptor_pool is wrong. Use Add() or AddSerializedFile() ' + 'instead.' % func.__name__, + category=DeprecationWarning) + return func(*args, **kwargs) + NewFunc.__name__ = func.__name__ + NewFunc.__doc__ = func.__doc__ + NewFunc.__dict__.update(func.__dict__) + return NewFunc + + +def _NormalizeFullyQualifiedName(name): + """Remove leading period from fully-qualified type name. + + Due to b/13860351 in descriptor_database.py, types in the root namespace are + generated with a leading period. This function removes that prefix. + + Args: + name (str): The fully-qualified symbol name. + + Returns: + str: The normalized fully-qualified symbol name. + """ + return name.lstrip('.') + + +def _OptionsOrNone(descriptor_proto): + """Returns the value of the field `options`, or None if it is not set.""" + if descriptor_proto.HasField('options'): + return descriptor_proto.options + else: + return None + + +def _IsMessageSetExtension(field): + return (field.is_extension and + field.containing_type.has_options and + field.containing_type.GetOptions().message_set_wire_format and + field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL) + + +class DescriptorPool(object): + """A collection of protobufs dynamically constructed by descriptor protos.""" + + if _USE_C_DESCRIPTORS: + + def __new__(cls, descriptor_db=None): + # pylint: disable=protected-access + return descriptor._message.DescriptorPool(descriptor_db) + + def __init__(self, descriptor_db=None): + """Initializes a Pool of proto buffs. + + The descriptor_db argument to the constructor is provided to allow + specialized file descriptor proto lookup code to be triggered on demand. An + example would be an implementation which will read and compile a file + specified in a call to FindFileByName() and not require the call to Add() + at all. Results from this database will be cached internally here as well. + + Args: + descriptor_db: A secondary source of file descriptors. + """ + + self._internal_db = descriptor_database.DescriptorDatabase() + self._descriptor_db = descriptor_db + self._descriptors = {} + self._enum_descriptors = {} + self._service_descriptors = {} + self._file_descriptors = {} + self._toplevel_extensions = {} + # TODO(jieluo): Remove _file_desc_by_toplevel_extension after + # maybe year 2020 for compatibility issue (with 3.4.1 only). + self._file_desc_by_toplevel_extension = {} + self._top_enum_values = {} + # We store extensions in two two-level mappings: The first key is the + # descriptor of the message being extended, the second key is the extension + # full name or its tag number. + self._extensions_by_name = collections.defaultdict(dict) + self._extensions_by_number = collections.defaultdict(dict) + + def _CheckConflictRegister(self, desc, desc_name, file_name): + """Check if the descriptor name conflicts with another of the same name. + + Args: + desc: Descriptor of a message, enum, service, extension or enum value. + desc_name (str): the full name of desc. + file_name (str): The file name of descriptor. + """ + for register, descriptor_type in [ + (self._descriptors, descriptor.Descriptor), + (self._enum_descriptors, descriptor.EnumDescriptor), + (self._service_descriptors, descriptor.ServiceDescriptor), + (self._toplevel_extensions, descriptor.FieldDescriptor), + (self._top_enum_values, descriptor.EnumValueDescriptor)]: + if desc_name in register: + old_desc = register[desc_name] + if isinstance(old_desc, descriptor.EnumValueDescriptor): + old_file = old_desc.type.file.name + else: + old_file = old_desc.file.name + + if not isinstance(desc, descriptor_type) or ( + old_file != file_name): + error_msg = ('Conflict register for file "' + file_name + + '": ' + desc_name + + ' is already defined in file "' + + old_file + '". Please fix the conflict by adding ' + 'package name on the proto file, or use different ' + 'name for the duplication.') + if isinstance(desc, descriptor.EnumValueDescriptor): + error_msg += ('\nNote: enum values appear as ' + 'siblings of the enum type instead of ' + 'children of it.') + + raise TypeError(error_msg) + + return + + def Add(self, file_desc_proto): + """Adds the FileDescriptorProto and its types to this pool. + + Args: + file_desc_proto (FileDescriptorProto): The file descriptor to add. + """ + + self._internal_db.Add(file_desc_proto) + + def AddSerializedFile(self, serialized_file_desc_proto): + """Adds the FileDescriptorProto and its types to this pool. + + Args: + serialized_file_desc_proto (bytes): A bytes string, serialization of the + :class:`FileDescriptorProto` to add. + """ + + # pylint: disable=g-import-not-at-top + from google.protobuf import descriptor_pb2 + file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString( + serialized_file_desc_proto) + self.Add(file_desc_proto) + + # Add Descriptor to descriptor pool is dreprecated. Please use Add() + # or AddSerializedFile() to add a FileDescriptorProto instead. + @_Deprecated + def AddDescriptor(self, desc): + self._AddDescriptor(desc) + + # Never call this method. It is for internal usage only. + def _AddDescriptor(self, desc): + """Adds a Descriptor to the pool, non-recursively. + + If the Descriptor contains nested messages or enums, the caller must + explicitly register them. This method also registers the FileDescriptor + associated with the message. + + Args: + desc: A Descriptor. + """ + if not isinstance(desc, descriptor.Descriptor): + raise TypeError('Expected instance of descriptor.Descriptor.') + + self._CheckConflictRegister(desc, desc.full_name, desc.file.name) + + self._descriptors[desc.full_name] = desc + self._AddFileDescriptor(desc.file) + + # Add EnumDescriptor to descriptor pool is dreprecated. Please use Add() + # or AddSerializedFile() to add a FileDescriptorProto instead. + @_Deprecated + def AddEnumDescriptor(self, enum_desc): + self._AddEnumDescriptor(enum_desc) + + # Never call this method. It is for internal usage only. + def _AddEnumDescriptor(self, enum_desc): + """Adds an EnumDescriptor to the pool. + + This method also registers the FileDescriptor associated with the enum. + + Args: + enum_desc: An EnumDescriptor. + """ + + if not isinstance(enum_desc, descriptor.EnumDescriptor): + raise TypeError('Expected instance of descriptor.EnumDescriptor.') + + file_name = enum_desc.file.name + self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name) + self._enum_descriptors[enum_desc.full_name] = enum_desc + + # Top enum values need to be indexed. + # Count the number of dots to see whether the enum is toplevel or nested + # in a message. We cannot use enum_desc.containing_type at this stage. + if enum_desc.file.package: + top_level = (enum_desc.full_name.count('.') + - enum_desc.file.package.count('.') == 1) + else: + top_level = enum_desc.full_name.count('.') == 0 + if top_level: + file_name = enum_desc.file.name + package = enum_desc.file.package + for enum_value in enum_desc.values: + full_name = _NormalizeFullyQualifiedName( + '.'.join((package, enum_value.name))) + self._CheckConflictRegister(enum_value, full_name, file_name) + self._top_enum_values[full_name] = enum_value + self._AddFileDescriptor(enum_desc.file) + + # Add ServiceDescriptor to descriptor pool is dreprecated. Please use Add() + # or AddSerializedFile() to add a FileDescriptorProto instead. + @_Deprecated + def AddServiceDescriptor(self, service_desc): + self._AddServiceDescriptor(service_desc) + + # Never call this method. It is for internal usage only. + def _AddServiceDescriptor(self, service_desc): + """Adds a ServiceDescriptor to the pool. + + Args: + service_desc: A ServiceDescriptor. + """ + + if not isinstance(service_desc, descriptor.ServiceDescriptor): + raise TypeError('Expected instance of descriptor.ServiceDescriptor.') + + self._CheckConflictRegister(service_desc, service_desc.full_name, + service_desc.file.name) + self._service_descriptors[service_desc.full_name] = service_desc + + # Add ExtensionDescriptor to descriptor pool is dreprecated. Please use Add() + # or AddSerializedFile() to add a FileDescriptorProto instead. + @_Deprecated + def AddExtensionDescriptor(self, extension): + self._AddExtensionDescriptor(extension) + + # Never call this method. It is for internal usage only. + def _AddExtensionDescriptor(self, extension): + """Adds a FieldDescriptor describing an extension to the pool. + + Args: + extension: A FieldDescriptor. + + Raises: + AssertionError: when another extension with the same number extends the + same message. + TypeError: when the specified extension is not a + descriptor.FieldDescriptor. + """ + if not (isinstance(extension, descriptor.FieldDescriptor) and + extension.is_extension): + raise TypeError('Expected an extension descriptor.') + + if extension.extension_scope is None: + self._toplevel_extensions[extension.full_name] = extension + + try: + existing_desc = self._extensions_by_number[ + extension.containing_type][extension.number] + except KeyError: + pass + else: + if extension is not existing_desc: + raise AssertionError( + 'Extensions "%s" and "%s" both try to extend message type "%s" ' + 'with field number %d.' % + (extension.full_name, existing_desc.full_name, + extension.containing_type.full_name, extension.number)) + + self._extensions_by_number[extension.containing_type][ + extension.number] = extension + self._extensions_by_name[extension.containing_type][ + extension.full_name] = extension + + # Also register MessageSet extensions with the type name. + if _IsMessageSetExtension(extension): + self._extensions_by_name[extension.containing_type][ + extension.message_type.full_name] = extension + + @_Deprecated + def AddFileDescriptor(self, file_desc): + self._InternalAddFileDescriptor(file_desc) + + # Never call this method. It is for internal usage only. + def _InternalAddFileDescriptor(self, file_desc): + """Adds a FileDescriptor to the pool, non-recursively. + + If the FileDescriptor contains messages or enums, the caller must explicitly + register them. + + Args: + file_desc: A FileDescriptor. + """ + + self._AddFileDescriptor(file_desc) + # TODO(jieluo): This is a temporary solution for FieldDescriptor.file. + # FieldDescriptor.file is added in code gen. Remove this solution after + # maybe 2020 for compatibility reason (with 3.4.1 only). + for extension in file_desc.extensions_by_name.values(): + self._file_desc_by_toplevel_extension[ + extension.full_name] = file_desc + + def _AddFileDescriptor(self, file_desc): + """Adds a FileDescriptor to the pool, non-recursively. + + If the FileDescriptor contains messages or enums, the caller must explicitly + register them. + + Args: + file_desc: A FileDescriptor. + """ + + if not isinstance(file_desc, descriptor.FileDescriptor): + raise TypeError('Expected instance of descriptor.FileDescriptor.') + self._file_descriptors[file_desc.name] = file_desc + + def FindFileByName(self, file_name): + """Gets a FileDescriptor by file name. + + Args: + file_name (str): The path to the file to get a descriptor for. + + Returns: + FileDescriptor: The descriptor for the named file. + + Raises: + KeyError: if the file cannot be found in the pool. + """ + + try: + return self._file_descriptors[file_name] + except KeyError: + pass + + try: + file_proto = self._internal_db.FindFileByName(file_name) + except KeyError as error: + if self._descriptor_db: + file_proto = self._descriptor_db.FindFileByName(file_name) + else: + raise error + if not file_proto: + raise KeyError('Cannot find a file named %s' % file_name) + return self._ConvertFileProtoToFileDescriptor(file_proto) + + def FindFileContainingSymbol(self, symbol): + """Gets the FileDescriptor for the file containing the specified symbol. + + Args: + symbol (str): The name of the symbol to search for. + + Returns: + FileDescriptor: Descriptor for the file that contains the specified + symbol. + + Raises: + KeyError: if the file cannot be found in the pool. + """ + + symbol = _NormalizeFullyQualifiedName(symbol) + try: + return self._InternalFindFileContainingSymbol(symbol) + except KeyError: + pass + + try: + # Try fallback database. Build and find again if possible. + self._FindFileContainingSymbolInDb(symbol) + return self._InternalFindFileContainingSymbol(symbol) + except KeyError: + raise KeyError('Cannot find a file containing %s' % symbol) + + def _InternalFindFileContainingSymbol(self, symbol): + """Gets the already built FileDescriptor containing the specified symbol. + + Args: + symbol (str): The name of the symbol to search for. + + Returns: + FileDescriptor: Descriptor for the file that contains the specified + symbol. + + Raises: + KeyError: if the file cannot be found in the pool. + """ + try: + return self._descriptors[symbol].file + except KeyError: + pass + + try: + return self._enum_descriptors[symbol].file + except KeyError: + pass + + try: + return self._service_descriptors[symbol].file + except KeyError: + pass + + try: + return self._top_enum_values[symbol].type.file + except KeyError: + pass + + try: + return self._file_desc_by_toplevel_extension[symbol] + except KeyError: + pass + + # Try fields, enum values and nested extensions inside a message. + top_name, _, sub_name = symbol.rpartition('.') + try: + message = self.FindMessageTypeByName(top_name) + assert (sub_name in message.extensions_by_name or + sub_name in message.fields_by_name or + sub_name in message.enum_values_by_name) + return message.file + except (KeyError, AssertionError): + raise KeyError('Cannot find a file containing %s' % symbol) + + def FindMessageTypeByName(self, full_name): + """Loads the named descriptor from the pool. + + Args: + full_name (str): The full name of the descriptor to load. + + Returns: + Descriptor: The descriptor for the named type. + + Raises: + KeyError: if the message cannot be found in the pool. + """ + + full_name = _NormalizeFullyQualifiedName(full_name) + if full_name not in self._descriptors: + self._FindFileContainingSymbolInDb(full_name) + return self._descriptors[full_name] + + def FindEnumTypeByName(self, full_name): + """Loads the named enum descriptor from the pool. + + Args: + full_name (str): The full name of the enum descriptor to load. + + Returns: + EnumDescriptor: The enum descriptor for the named type. + + Raises: + KeyError: if the enum cannot be found in the pool. + """ + + full_name = _NormalizeFullyQualifiedName(full_name) + if full_name not in self._enum_descriptors: + self._FindFileContainingSymbolInDb(full_name) + return self._enum_descriptors[full_name] + + def FindFieldByName(self, full_name): + """Loads the named field descriptor from the pool. + + Args: + full_name (str): The full name of the field descriptor to load. + + Returns: + FieldDescriptor: The field descriptor for the named field. + + Raises: + KeyError: if the field cannot be found in the pool. + """ + full_name = _NormalizeFullyQualifiedName(full_name) + message_name, _, field_name = full_name.rpartition('.') + message_descriptor = self.FindMessageTypeByName(message_name) + return message_descriptor.fields_by_name[field_name] + + def FindOneofByName(self, full_name): + """Loads the named oneof descriptor from the pool. + + Args: + full_name (str): The full name of the oneof descriptor to load. + + Returns: + OneofDescriptor: The oneof descriptor for the named oneof. + + Raises: + KeyError: if the oneof cannot be found in the pool. + """ + full_name = _NormalizeFullyQualifiedName(full_name) + message_name, _, oneof_name = full_name.rpartition('.') + message_descriptor = self.FindMessageTypeByName(message_name) + return message_descriptor.oneofs_by_name[oneof_name] + + def FindExtensionByName(self, full_name): + """Loads the named extension descriptor from the pool. + + Args: + full_name (str): The full name of the extension descriptor to load. + + Returns: + FieldDescriptor: The field descriptor for the named extension. + + Raises: + KeyError: if the extension cannot be found in the pool. + """ + full_name = _NormalizeFullyQualifiedName(full_name) + try: + # The proto compiler does not give any link between the FileDescriptor + # and top-level extensions unless the FileDescriptorProto is added to + # the DescriptorDatabase, but this can impact memory usage. + # So we registered these extensions by name explicitly. + return self._toplevel_extensions[full_name] + except KeyError: + pass + message_name, _, extension_name = full_name.rpartition('.') + try: + # Most extensions are nested inside a message. + scope = self.FindMessageTypeByName(message_name) + except KeyError: + # Some extensions are defined at file scope. + scope = self._FindFileContainingSymbolInDb(full_name) + return scope.extensions_by_name[extension_name] + + def FindExtensionByNumber(self, message_descriptor, number): + """Gets the extension of the specified message with the specified number. + + Extensions have to be registered to this pool by calling :func:`Add` or + :func:`AddExtensionDescriptor`. + + Args: + message_descriptor (Descriptor): descriptor of the extended message. + number (int): Number of the extension field. + + Returns: + FieldDescriptor: The descriptor for the extension. + + Raises: + KeyError: when no extension with the given number is known for the + specified message. + """ + try: + return self._extensions_by_number[message_descriptor][number] + except KeyError: + self._TryLoadExtensionFromDB(message_descriptor, number) + return self._extensions_by_number[message_descriptor][number] + + def FindAllExtensions(self, message_descriptor): + """Gets all the known extensions of a given message. + + Extensions have to be registered to this pool by build related + :func:`Add` or :func:`AddExtensionDescriptor`. + + Args: + message_descriptor (Descriptor): Descriptor of the extended message. + + Returns: + list[FieldDescriptor]: Field descriptors describing the extensions. + """ + # Fallback to descriptor db if FindAllExtensionNumbers is provided. + if self._descriptor_db and hasattr( + self._descriptor_db, 'FindAllExtensionNumbers'): + full_name = message_descriptor.full_name + all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name) + for number in all_numbers: + if number in self._extensions_by_number[message_descriptor]: + continue + self._TryLoadExtensionFromDB(message_descriptor, number) + + return list(self._extensions_by_number[message_descriptor].values()) + + def _TryLoadExtensionFromDB(self, message_descriptor, number): + """Try to Load extensions from descriptor db. + + Args: + message_descriptor: descriptor of the extended message. + number: the extension number that needs to be loaded. + """ + if not self._descriptor_db: + return + # Only supported when FindFileContainingExtension is provided. + if not hasattr( + self._descriptor_db, 'FindFileContainingExtension'): + return + + full_name = message_descriptor.full_name + file_proto = self._descriptor_db.FindFileContainingExtension( + full_name, number) + + if file_proto is None: + return + + try: + self._ConvertFileProtoToFileDescriptor(file_proto) + except: + warn_msg = ('Unable to load proto file %s for extension number %d.' % + (file_proto.name, number)) + warnings.warn(warn_msg, RuntimeWarning) + + def FindServiceByName(self, full_name): + """Loads the named service descriptor from the pool. + + Args: + full_name (str): The full name of the service descriptor to load. + + Returns: + ServiceDescriptor: The service descriptor for the named service. + + Raises: + KeyError: if the service cannot be found in the pool. + """ + full_name = _NormalizeFullyQualifiedName(full_name) + if full_name not in self._service_descriptors: + self._FindFileContainingSymbolInDb(full_name) + return self._service_descriptors[full_name] + + def FindMethodByName(self, full_name): + """Loads the named service method descriptor from the pool. + + Args: + full_name (str): The full name of the method descriptor to load. + + Returns: + MethodDescriptor: The method descriptor for the service method. + + Raises: + KeyError: if the method cannot be found in the pool. + """ + full_name = _NormalizeFullyQualifiedName(full_name) + service_name, _, method_name = full_name.rpartition('.') + service_descriptor = self.FindServiceByName(service_name) + return service_descriptor.methods_by_name[method_name] + + def _FindFileContainingSymbolInDb(self, symbol): + """Finds the file in descriptor DB containing the specified symbol. + + Args: + symbol (str): The name of the symbol to search for. + + Returns: + FileDescriptor: The file that contains the specified symbol. + + Raises: + KeyError: if the file cannot be found in the descriptor database. + """ + try: + file_proto = self._internal_db.FindFileContainingSymbol(symbol) + except KeyError as error: + if self._descriptor_db: + file_proto = self._descriptor_db.FindFileContainingSymbol(symbol) + else: + raise error + if not file_proto: + raise KeyError('Cannot find a file containing %s' % symbol) + return self._ConvertFileProtoToFileDescriptor(file_proto) + + def _ConvertFileProtoToFileDescriptor(self, file_proto): + """Creates a FileDescriptor from a proto or returns a cached copy. + + This method also has the side effect of loading all the symbols found in + the file into the appropriate dictionaries in the pool. + + Args: + file_proto: The proto to convert. + + Returns: + A FileDescriptor matching the passed in proto. + """ + if file_proto.name not in self._file_descriptors: + built_deps = list(self._GetDeps(file_proto.dependency)) + direct_deps = [self.FindFileByName(n) for n in file_proto.dependency] + public_deps = [direct_deps[i] for i in file_proto.public_dependency] + + file_descriptor = descriptor.FileDescriptor( + pool=self, + name=file_proto.name, + package=file_proto.package, + syntax=file_proto.syntax, + options=_OptionsOrNone(file_proto), + serialized_pb=file_proto.SerializeToString(), + dependencies=direct_deps, + public_dependencies=public_deps, + # pylint: disable=protected-access + create_key=descriptor._internal_create_key) + scope = {} + + # This loop extracts all the message and enum types from all the + # dependencies of the file_proto. This is necessary to create the + # scope of available message types when defining the passed in + # file proto. + for dependency in built_deps: + scope.update(self._ExtractSymbols( + dependency.message_types_by_name.values())) + scope.update((_PrefixWithDot(enum.full_name), enum) + for enum in dependency.enum_types_by_name.values()) + + for message_type in file_proto.message_type: + message_desc = self._ConvertMessageDescriptor( + message_type, file_proto.package, file_descriptor, scope, + file_proto.syntax) + file_descriptor.message_types_by_name[message_desc.name] = ( + message_desc) + + for enum_type in file_proto.enum_type: + file_descriptor.enum_types_by_name[enum_type.name] = ( + self._ConvertEnumDescriptor(enum_type, file_proto.package, + file_descriptor, None, scope, True)) + + for index, extension_proto in enumerate(file_proto.extension): + extension_desc = self._MakeFieldDescriptor( + extension_proto, file_proto.package, index, file_descriptor, + is_extension=True) + extension_desc.containing_type = self._GetTypeFromScope( + file_descriptor.package, extension_proto.extendee, scope) + self._SetFieldType(extension_proto, extension_desc, + file_descriptor.package, scope) + file_descriptor.extensions_by_name[extension_desc.name] = ( + extension_desc) + self._file_desc_by_toplevel_extension[extension_desc.full_name] = ( + file_descriptor) + + for desc_proto in file_proto.message_type: + self._SetAllFieldTypes(file_proto.package, desc_proto, scope) + + if file_proto.package: + desc_proto_prefix = _PrefixWithDot(file_proto.package) + else: + desc_proto_prefix = '' + + for desc_proto in file_proto.message_type: + desc = self._GetTypeFromScope( + desc_proto_prefix, desc_proto.name, scope) + file_descriptor.message_types_by_name[desc_proto.name] = desc + + for index, service_proto in enumerate(file_proto.service): + file_descriptor.services_by_name[service_proto.name] = ( + self._MakeServiceDescriptor(service_proto, index, scope, + file_proto.package, file_descriptor)) + + self.Add(file_proto) + self._file_descriptors[file_proto.name] = file_descriptor + + # Add extensions to the pool + file_desc = self._file_descriptors[file_proto.name] + for extension in file_desc.extensions_by_name.values(): + self._AddExtensionDescriptor(extension) + for message_type in file_desc.message_types_by_name.values(): + for extension in message_type.extensions: + self._AddExtensionDescriptor(extension) + + return file_desc + + def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None, + scope=None, syntax=None): + """Adds the proto to the pool in the specified package. + + Args: + desc_proto: The descriptor_pb2.DescriptorProto protobuf message. + package: The package the proto should be located in. + file_desc: The file containing this message. + scope: Dict mapping short and full symbols to message and enum types. + syntax: string indicating syntax of the file ("proto2" or "proto3") + + Returns: + The added descriptor. + """ + + if package: + desc_name = '.'.join((package, desc_proto.name)) + else: + desc_name = desc_proto.name + + if file_desc is None: + file_name = None + else: + file_name = file_desc.name + + if scope is None: + scope = {} + + nested = [ + self._ConvertMessageDescriptor( + nested, desc_name, file_desc, scope, syntax) + for nested in desc_proto.nested_type] + enums = [ + self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, + scope, False) + for enum in desc_proto.enum_type] + fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc) + for index, field in enumerate(desc_proto.field)] + extensions = [ + self._MakeFieldDescriptor(extension, desc_name, index, file_desc, + is_extension=True) + for index, extension in enumerate(desc_proto.extension)] + oneofs = [ + # pylint: disable=g-complex-comprehension + descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)), + index, None, [], desc.options, + # pylint: disable=protected-access + create_key=descriptor._internal_create_key) + for index, desc in enumerate(desc_proto.oneof_decl)] + extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range] + if extension_ranges: + is_extendable = True + else: + is_extendable = False + desc = descriptor.Descriptor( + name=desc_proto.name, + full_name=desc_name, + filename=file_name, + containing_type=None, + fields=fields, + oneofs=oneofs, + nested_types=nested, + enum_types=enums, + extensions=extensions, + options=_OptionsOrNone(desc_proto), + is_extendable=is_extendable, + extension_ranges=extension_ranges, + file=file_desc, + serialized_start=None, + serialized_end=None, + syntax=syntax, + # pylint: disable=protected-access + create_key=descriptor._internal_create_key) + for nested in desc.nested_types: + nested.containing_type = desc + for enum in desc.enum_types: + enum.containing_type = desc + for field_index, field_desc in enumerate(desc_proto.field): + if field_desc.HasField('oneof_index'): + oneof_index = field_desc.oneof_index + oneofs[oneof_index].fields.append(fields[field_index]) + fields[field_index].containing_oneof = oneofs[oneof_index] + + scope[_PrefixWithDot(desc_name)] = desc + self._CheckConflictRegister(desc, desc.full_name, desc.file.name) + self._descriptors[desc_name] = desc + return desc + + def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None, + containing_type=None, scope=None, top_level=False): + """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf. + + Args: + enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message. + package: Optional package name for the new message EnumDescriptor. + file_desc: The file containing the enum descriptor. + containing_type: The type containing this enum. + scope: Scope containing available types. + top_level: If True, the enum is a top level symbol. If False, the enum + is defined inside a message. + + Returns: + The added descriptor + """ + + if package: + enum_name = '.'.join((package, enum_proto.name)) + else: + enum_name = enum_proto.name + + if file_desc is None: + file_name = None + else: + file_name = file_desc.name + + values = [self._MakeEnumValueDescriptor(value, index) + for index, value in enumerate(enum_proto.value)] + desc = descriptor.EnumDescriptor(name=enum_proto.name, + full_name=enum_name, + filename=file_name, + file=file_desc, + values=values, + containing_type=containing_type, + options=_OptionsOrNone(enum_proto), + # pylint: disable=protected-access + create_key=descriptor._internal_create_key) + scope['.%s' % enum_name] = desc + self._CheckConflictRegister(desc, desc.full_name, desc.file.name) + self._enum_descriptors[enum_name] = desc + + # Add top level enum values. + if top_level: + for value in values: + full_name = _NormalizeFullyQualifiedName( + '.'.join((package, value.name))) + self._CheckConflictRegister(value, full_name, file_name) + self._top_enum_values[full_name] = value + + return desc + + def _MakeFieldDescriptor(self, field_proto, message_name, index, + file_desc, is_extension=False): + """Creates a field descriptor from a FieldDescriptorProto. + + For message and enum type fields, this method will do a look up + in the pool for the appropriate descriptor for that type. If it + is unavailable, it will fall back to the _source function to + create it. If this type is still unavailable, construction will + fail. + + Args: + field_proto: The proto describing the field. + message_name: The name of the containing message. + index: Index of the field + file_desc: The file containing the field descriptor. + is_extension: Indication that this field is for an extension. + + Returns: + An initialized FieldDescriptor object + """ + + if message_name: + full_name = '.'.join((message_name, field_proto.name)) + else: + full_name = field_proto.name + + return descriptor.FieldDescriptor( + name=field_proto.name, + full_name=full_name, + index=index, + number=field_proto.number, + type=field_proto.type, + cpp_type=None, + message_type=None, + enum_type=None, + containing_type=None, + label=field_proto.label, + has_default_value=False, + default_value=None, + is_extension=is_extension, + extension_scope=None, + options=_OptionsOrNone(field_proto), + file=file_desc, + # pylint: disable=protected-access + create_key=descriptor._internal_create_key) + + def _SetAllFieldTypes(self, package, desc_proto, scope): + """Sets all the descriptor's fields's types. + + This method also sets the containing types on any extensions. + + Args: + package: The current package of desc_proto. + desc_proto: The message descriptor to update. + scope: Enclosing scope of available types. + """ + + package = _PrefixWithDot(package) + + main_desc = self._GetTypeFromScope(package, desc_proto.name, scope) + + if package == '.': + nested_package = _PrefixWithDot(desc_proto.name) + else: + nested_package = '.'.join([package, desc_proto.name]) + + for field_proto, field_desc in zip(desc_proto.field, main_desc.fields): + self._SetFieldType(field_proto, field_desc, nested_package, scope) + + for extension_proto, extension_desc in ( + zip(desc_proto.extension, main_desc.extensions)): + extension_desc.containing_type = self._GetTypeFromScope( + nested_package, extension_proto.extendee, scope) + self._SetFieldType(extension_proto, extension_desc, nested_package, scope) + + for nested_type in desc_proto.nested_type: + self._SetAllFieldTypes(nested_package, nested_type, scope) + + def _SetFieldType(self, field_proto, field_desc, package, scope): + """Sets the field's type, cpp_type, message_type and enum_type. + + Args: + field_proto: Data about the field in proto format. + field_desc: The descriptor to modify. + package: The package the field's container is in. + scope: Enclosing scope of available types. + """ + if field_proto.type_name: + desc = self._GetTypeFromScope(package, field_proto.type_name, scope) + else: + desc = None + + if not field_proto.HasField('type'): + if isinstance(desc, descriptor.Descriptor): + field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE + else: + field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM + + field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType( + field_proto.type) + + if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE + or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP): + field_desc.message_type = desc + + if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: + field_desc.enum_type = desc + + if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED: + field_desc.has_default_value = False + field_desc.default_value = [] + elif field_proto.HasField('default_value'): + field_desc.has_default_value = True + if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or + field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): + field_desc.default_value = float(field_proto.default_value) + elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: + field_desc.default_value = field_proto.default_value + elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: + field_desc.default_value = field_proto.default_value.lower() == 'true' + elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: + field_desc.default_value = field_desc.enum_type.values_by_name[ + field_proto.default_value].number + elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: + field_desc.default_value = text_encoding.CUnescape( + field_proto.default_value) + elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE: + field_desc.default_value = None + else: + # All other types are of the "int" type. + field_desc.default_value = int(field_proto.default_value) + else: + field_desc.has_default_value = False + if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or + field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): + field_desc.default_value = 0.0 + elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: + field_desc.default_value = u'' + elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: + field_desc.default_value = False + elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: + field_desc.default_value = field_desc.enum_type.values[0].number + elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: + field_desc.default_value = b'' + elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE: + field_desc.default_value = None + else: + # All other types are of the "int" type. + field_desc.default_value = 0 + + field_desc.type = field_proto.type + + def _MakeEnumValueDescriptor(self, value_proto, index): + """Creates a enum value descriptor object from a enum value proto. + + Args: + value_proto: The proto describing the enum value. + index: The index of the enum value. + + Returns: + An initialized EnumValueDescriptor object. + """ + + return descriptor.EnumValueDescriptor( + name=value_proto.name, + index=index, + number=value_proto.number, + options=_OptionsOrNone(value_proto), + type=None, + # pylint: disable=protected-access + create_key=descriptor._internal_create_key) + + def _MakeServiceDescriptor(self, service_proto, service_index, scope, + package, file_desc): + """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto. + + Args: + service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message. + service_index: The index of the service in the File. + scope: Dict mapping short and full symbols to message and enum types. + package: Optional package name for the new message EnumDescriptor. + file_desc: The file containing the service descriptor. + + Returns: + The added descriptor. + """ + + if package: + service_name = '.'.join((package, service_proto.name)) + else: + service_name = service_proto.name + + methods = [self._MakeMethodDescriptor(method_proto, service_name, package, + scope, index) + for index, method_proto in enumerate(service_proto.method)] + desc = descriptor.ServiceDescriptor( + name=service_proto.name, + full_name=service_name, + index=service_index, + methods=methods, + options=_OptionsOrNone(service_proto), + file=file_desc, + # pylint: disable=protected-access + create_key=descriptor._internal_create_key) + self._CheckConflictRegister(desc, desc.full_name, desc.file.name) + self._service_descriptors[service_name] = desc + return desc + + def _MakeMethodDescriptor(self, method_proto, service_name, package, scope, + index): + """Creates a method descriptor from a MethodDescriptorProto. + + Args: + method_proto: The proto describing the method. + service_name: The name of the containing service. + package: Optional package name to look up for types. + scope: Scope containing available types. + index: Index of the method in the service. + + Returns: + An initialized MethodDescriptor object. + """ + full_name = '.'.join((service_name, method_proto.name)) + input_type = self._GetTypeFromScope( + package, method_proto.input_type, scope) + output_type = self._GetTypeFromScope( + package, method_proto.output_type, scope) + return descriptor.MethodDescriptor( + name=method_proto.name, + full_name=full_name, + index=index, + containing_service=None, + input_type=input_type, + output_type=output_type, + options=_OptionsOrNone(method_proto), + # pylint: disable=protected-access + create_key=descriptor._internal_create_key) + + def _ExtractSymbols(self, descriptors): + """Pulls out all the symbols from descriptor protos. + + Args: + descriptors: The messages to extract descriptors from. + Yields: + A two element tuple of the type name and descriptor object. + """ + + for desc in descriptors: + yield (_PrefixWithDot(desc.full_name), desc) + for symbol in self._ExtractSymbols(desc.nested_types): + yield symbol + for enum in desc.enum_types: + yield (_PrefixWithDot(enum.full_name), enum) + + def _GetDeps(self, dependencies): + """Recursively finds dependencies for file protos. + + Args: + dependencies: The names of the files being depended on. + + Yields: + Each direct and indirect dependency. + """ + + for dependency in dependencies: + dep_desc = self.FindFileByName(dependency) + yield dep_desc + for parent_dep in dep_desc.dependencies: + yield parent_dep + + def _GetTypeFromScope(self, package, type_name, scope): + """Finds a given type name in the current scope. + + Args: + package: The package the proto should be located in. + type_name: The name of the type to be found in the scope. + scope: Dict mapping short and full symbols to message and enum types. + + Returns: + The descriptor for the requested type. + """ + if type_name not in scope: + components = _PrefixWithDot(package).split('.') + while components: + possible_match = '.'.join(components + [type_name]) + if possible_match in scope: + type_name = possible_match + break + else: + components.pop(-1) + return scope[type_name] + + +def _PrefixWithDot(name): + return name if name.startswith('.') else '.%s' % name + + +if _USE_C_DESCRIPTORS: + # TODO(amauryfa): This pool could be constructed from Python code, when we + # support a flag like 'use_cpp_generated_pool=True'. + # pylint: disable=protected-access + _DEFAULT = descriptor._message.default_pool +else: + _DEFAULT = DescriptorPool() + + +def Default(): + return _DEFAULT diff --git a/contrib/python/protobuf/py3/google/protobuf/internal/__init__.py b/contrib/python/protobuf/py3/google/protobuf/internal/__init__.py new file mode 100644 index 0000000000..7d2e571a14 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/internal/__init__.py @@ -0,0 +1,30 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + diff --git a/contrib/python/protobuf/py3/google/protobuf/internal/_parameterized.py b/contrib/python/protobuf/py3/google/protobuf/internal/_parameterized.py new file mode 100755 index 0000000000..4cba1d479d --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/internal/_parameterized.py @@ -0,0 +1,449 @@ +#! /usr/bin/env python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Adds support for parameterized tests to Python's unittest TestCase class. + +A parameterized test is a method in a test case that is invoked with different +argument tuples. + +A simple example: + + class AdditionExample(parameterized.TestCase): + @parameterized.parameters( + (1, 2, 3), + (4, 5, 9), + (1, 1, 3)) + def testAddition(self, op1, op2, result): + self.assertEqual(result, op1 + op2) + + +Each invocation is a separate test case and properly isolated just +like a normal test method, with its own setUp/tearDown cycle. In the +example above, there are three separate testcases, one of which will +fail due to an assertion error (1 + 1 != 3). + +Parameters for individual test cases can be tuples (with positional parameters) +or dictionaries (with named parameters): + + class AdditionExample(parameterized.TestCase): + @parameterized.parameters( + {'op1': 1, 'op2': 2, 'result': 3}, + {'op1': 4, 'op2': 5, 'result': 9}, + ) + def testAddition(self, op1, op2, result): + self.assertEqual(result, op1 + op2) + +If a parameterized test fails, the error message will show the +original test name (which is modified internally) and the arguments +for the specific invocation, which are part of the string returned by +the shortDescription() method on test cases. + +The id method of the test, used internally by the unittest framework, +is also modified to show the arguments. To make sure that test names +stay the same across several invocations, object representations like + + >>> class Foo(object): + ... pass + >>> repr(Foo()) + '<__main__.Foo object at 0x23d8610>' + +are turned into '<__main__.Foo>'. For even more descriptive names, +especially in test logs, you can use the named_parameters decorator. In +this case, only tuples are supported, and the first parameters has to +be a string (or an object that returns an apt name when converted via +str()): + + class NamedExample(parameterized.TestCase): + @parameterized.named_parameters( + ('Normal', 'aa', 'aaa', True), + ('EmptyPrefix', '', 'abc', True), + ('BothEmpty', '', '', True)) + def testStartsWith(self, prefix, string, result): + self.assertEqual(result, strings.startswith(prefix)) + +Named tests also have the benefit that they can be run individually +from the command line: + + $ testmodule.py NamedExample.testStartsWithNormal + . + -------------------------------------------------------------------- + Ran 1 test in 0.000s + + OK + +Parameterized Classes +===================== +If invocation arguments are shared across test methods in a single +TestCase class, instead of decorating all test methods +individually, the class itself can be decorated: + + @parameterized.parameters( + (1, 2, 3) + (4, 5, 9)) + class ArithmeticTest(parameterized.TestCase): + def testAdd(self, arg1, arg2, result): + self.assertEqual(arg1 + arg2, result) + + def testSubtract(self, arg2, arg2, result): + self.assertEqual(result - arg1, arg2) + +Inputs from Iterables +===================== +If parameters should be shared across several test cases, or are dynamically +created from other sources, a single non-tuple iterable can be passed into +the decorator. This iterable will be used to obtain the test cases: + + class AdditionExample(parameterized.TestCase): + @parameterized.parameters( + c.op1, c.op2, c.result for c in testcases + ) + def testAddition(self, op1, op2, result): + self.assertEqual(result, op1 + op2) + + +Single-Argument Test Methods +============================ +If a test method takes only one argument, the single argument does not need to +be wrapped into a tuple: + + class NegativeNumberExample(parameterized.TestCase): + @parameterized.parameters( + -1, -3, -4, -5 + ) + def testIsNegative(self, arg): + self.assertTrue(IsNegative(arg)) +""" + +__author__ = 'tmarek@google.com (Torsten Marek)' + +import functools +import re +import types +try: + import unittest2 as unittest +except ImportError: + import unittest +import uuid + +import six + +try: + # Since python 3 + import collections.abc as collections_abc +except ImportError: + # Won't work after python 3.8 + import collections as collections_abc + +ADDR_RE = re.compile(r'\<([a-zA-Z0-9_\-\.]+) object at 0x[a-fA-F0-9]+\>') +_SEPARATOR = uuid.uuid1().hex +_FIRST_ARG = object() +_ARGUMENT_REPR = object() + + +def _CleanRepr(obj): + return ADDR_RE.sub(r'<\1>', repr(obj)) + + +# Helper function formerly from the unittest module, removed from it in +# Python 2.7. +def _StrClass(cls): + return '%s.%s' % (cls.__module__, cls.__name__) + + +def _NonStringIterable(obj): + return (isinstance(obj, collections_abc.Iterable) and not + isinstance(obj, six.string_types)) + + +def _FormatParameterList(testcase_params): + if isinstance(testcase_params, collections_abc.Mapping): + return ', '.join('%s=%s' % (argname, _CleanRepr(value)) + for argname, value in testcase_params.items()) + elif _NonStringIterable(testcase_params): + return ', '.join(map(_CleanRepr, testcase_params)) + else: + return _FormatParameterList((testcase_params,)) + + +class _ParameterizedTestIter(object): + """Callable and iterable class for producing new test cases.""" + + def __init__(self, test_method, testcases, naming_type): + """Returns concrete test functions for a test and a list of parameters. + + The naming_type is used to determine the name of the concrete + functions as reported by the unittest framework. If naming_type is + _FIRST_ARG, the testcases must be tuples, and the first element must + have a string representation that is a valid Python identifier. + + Args: + test_method: The decorated test method. + testcases: (list of tuple/dict) A list of parameter + tuples/dicts for individual test invocations. + naming_type: The test naming type, either _NAMED or _ARGUMENT_REPR. + """ + self._test_method = test_method + self.testcases = testcases + self._naming_type = naming_type + + def __call__(self, *args, **kwargs): + raise RuntimeError('You appear to be running a parameterized test case ' + 'without having inherited from parameterized.' + 'TestCase. This is bad because none of ' + 'your test cases are actually being run.') + + def __iter__(self): + test_method = self._test_method + naming_type = self._naming_type + + def MakeBoundParamTest(testcase_params): + @functools.wraps(test_method) + def BoundParamTest(self): + if isinstance(testcase_params, collections_abc.Mapping): + test_method(self, **testcase_params) + elif _NonStringIterable(testcase_params): + test_method(self, *testcase_params) + else: + test_method(self, testcase_params) + + if naming_type is _FIRST_ARG: + # Signal the metaclass that the name of the test function is unique + # and descriptive. + BoundParamTest.__x_use_name__ = True + BoundParamTest.__name__ += str(testcase_params[0]) + testcase_params = testcase_params[1:] + elif naming_type is _ARGUMENT_REPR: + # __x_extra_id__ is used to pass naming information to the __new__ + # method of TestGeneratorMetaclass. + # The metaclass will make sure to create a unique, but nondescriptive + # name for this test. + BoundParamTest.__x_extra_id__ = '(%s)' % ( + _FormatParameterList(testcase_params),) + else: + raise RuntimeError('%s is not a valid naming type.' % (naming_type,)) + + BoundParamTest.__doc__ = '%s(%s)' % ( + BoundParamTest.__name__, _FormatParameterList(testcase_params)) + if test_method.__doc__: + BoundParamTest.__doc__ += '\n%s' % (test_method.__doc__,) + return BoundParamTest + return (MakeBoundParamTest(c) for c in self.testcases) + + +def _IsSingletonList(testcases): + """True iff testcases contains only a single non-tuple element.""" + return len(testcases) == 1 and not isinstance(testcases[0], tuple) + + +def _ModifyClass(class_object, testcases, naming_type): + assert not getattr(class_object, '_id_suffix', None), ( + 'Cannot add parameters to %s,' + ' which already has parameterized methods.' % (class_object,)) + class_object._id_suffix = id_suffix = {} + # We change the size of __dict__ while we iterate over it, + # which Python 3.x will complain about, so use copy(). + for name, obj in class_object.__dict__.copy().items(): + if (name.startswith(unittest.TestLoader.testMethodPrefix) + and isinstance(obj, types.FunctionType)): + delattr(class_object, name) + methods = {} + _UpdateClassDictForParamTestCase( + methods, id_suffix, name, + _ParameterizedTestIter(obj, testcases, naming_type)) + for name, meth in methods.items(): + setattr(class_object, name, meth) + + +def _ParameterDecorator(naming_type, testcases): + """Implementation of the parameterization decorators. + + Args: + naming_type: The naming type. + testcases: Testcase parameters. + + Returns: + A function for modifying the decorated object. + """ + def _Apply(obj): + if isinstance(obj, type): + _ModifyClass( + obj, + list(testcases) if not isinstance(testcases, collections_abc.Sequence) + else testcases, + naming_type) + return obj + else: + return _ParameterizedTestIter(obj, testcases, naming_type) + + if _IsSingletonList(testcases): + assert _NonStringIterable(testcases[0]), ( + 'Single parameter argument must be a non-string iterable') + testcases = testcases[0] + + return _Apply + + +def parameters(*testcases): # pylint: disable=invalid-name + """A decorator for creating parameterized tests. + + See the module docstring for a usage example. + Args: + *testcases: Parameters for the decorated method, either a single + iterable, or a list of tuples/dicts/objects (for tests + with only one argument). + + Returns: + A test generator to be handled by TestGeneratorMetaclass. + """ + return _ParameterDecorator(_ARGUMENT_REPR, testcases) + + +def named_parameters(*testcases): # pylint: disable=invalid-name + """A decorator for creating parameterized tests. + + See the module docstring for a usage example. The first element of + each parameter tuple should be a string and will be appended to the + name of the test method. + + Args: + *testcases: Parameters for the decorated method, either a single + iterable, or a list of tuples. + + Returns: + A test generator to be handled by TestGeneratorMetaclass. + """ + return _ParameterDecorator(_FIRST_ARG, testcases) + + +class TestGeneratorMetaclass(type): + """Metaclass for test cases with test generators. + + A test generator is an iterable in a testcase that produces callables. These + callables must be single-argument methods. These methods are injected into + the class namespace and the original iterable is removed. If the name of the + iterable conforms to the test pattern, the injected methods will be picked + up as tests by the unittest framework. + + In general, it is supposed to be used in conjunction with the + parameters decorator. + """ + + def __new__(mcs, class_name, bases, dct): + dct['_id_suffix'] = id_suffix = {} + for name, obj in dct.items(): + if (name.startswith(unittest.TestLoader.testMethodPrefix) and + _NonStringIterable(obj)): + iterator = iter(obj) + dct.pop(name) + _UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator) + + return type.__new__(mcs, class_name, bases, dct) + + +def _UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator): + """Adds individual test cases to a dictionary. + + Args: + dct: The target dictionary. + id_suffix: The dictionary for mapping names to test IDs. + name: The original name of the test case. + iterator: The iterator generating the individual test cases. + """ + for idx, func in enumerate(iterator): + assert callable(func), 'Test generators must yield callables, got %r' % ( + func,) + if getattr(func, '__x_use_name__', False): + new_name = func.__name__ + else: + new_name = '%s%s%d' % (name, _SEPARATOR, idx) + assert new_name not in dct, ( + 'Name of parameterized test case "%s" not unique' % (new_name,)) + dct[new_name] = func + id_suffix[new_name] = getattr(func, '__x_extra_id__', '') + + +class TestCase(unittest.TestCase): + """Base class for test cases using the parameters decorator.""" + __metaclass__ = TestGeneratorMetaclass + + def _OriginalName(self): + return self._testMethodName.split(_SEPARATOR)[0] + + def __str__(self): + return '%s (%s)' % (self._OriginalName(), _StrClass(self.__class__)) + + def id(self): # pylint: disable=invalid-name + """Returns the descriptive ID of the test. + + This is used internally by the unittesting framework to get a name + for the test to be used in reports. + + Returns: + The test id. + """ + return '%s.%s%s' % (_StrClass(self.__class__), + self._OriginalName(), + self._id_suffix.get(self._testMethodName, '')) + + +def CoopTestCase(other_base_class): + """Returns a new base class with a cooperative metaclass base. + + This enables the TestCase to be used in combination + with other base classes that have custom metaclasses, such as + mox.MoxTestBase. + + Only works with metaclasses that do not override type.__new__. + + Example: + + import google3 + import mox + + from google3.testing.pybase import parameterized + + class ExampleTest(parameterized.CoopTestCase(mox.MoxTestBase)): + ... + + Args: + other_base_class: (class) A test case base class. + + Returns: + A new class object. + """ + metaclass = type( + 'CoopMetaclass', + (other_base_class.__metaclass__, + TestGeneratorMetaclass), {}) + return metaclass( + 'CoopTestCase', + (other_base_class, TestCase), {}) diff --git a/contrib/python/protobuf/py3/google/protobuf/internal/api_implementation.cc b/contrib/python/protobuf/py3/google/protobuf/internal/api_implementation.cc new file mode 100644 index 0000000000..2a7f41d2d6 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/internal/api_implementation.cc @@ -0,0 +1,127 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include <Python.h> + +namespace google { +namespace protobuf { +namespace python { + +// Version constant. +// This is either 0 for python, 1 for CPP V1, 2 for CPP V2. +// +// 0 is default and is equivalent to +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python +// +// 1 is set with -DPYTHON_PROTO2_CPP_IMPL_V1 and is equivalent to +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp +// and +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=1 +// +// 2 is set with -DPYTHON_PROTO2_CPP_IMPL_V2 and is equivalent to +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp +// and +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2 +#ifdef PYTHON_PROTO2_CPP_IMPL_V1 +#error "PYTHON_PROTO2_CPP_IMPL_V1 is no longer supported." +#else +#ifdef PYTHON_PROTO2_CPP_IMPL_V2 +static int kImplVersion = 2; +#else +#ifdef PYTHON_PROTO2_PYTHON_IMPL +static int kImplVersion = 0; +#else + +static int kImplVersion = -1; // -1 means "Unspecified by compiler flags". + +#endif // PYTHON_PROTO2_PYTHON_IMPL +#endif // PYTHON_PROTO2_CPP_IMPL_V2 +#endif // PYTHON_PROTO2_CPP_IMPL_V1 + +static const char* kImplVersionName = "api_version"; + +static const char* kModuleName = "_api_implementation"; +static const char kModuleDocstring[] = + "_api_implementation is a module that exposes compile-time constants that\n" + "determine the default API implementation to use for Python proto2.\n" + "\n" + "It complements api_implementation.py by setting defaults using " + "compile-time\n" + "constants defined in C, such that one can set defaults at compilation\n" + "(e.g. with blaze flag --copt=-DPYTHON_PROTO2_CPP_IMPL_V2)."; + +#if PY_MAJOR_VERSION >= 3 +static struct PyModuleDef _module = {PyModuleDef_HEAD_INIT, + kModuleName, + kModuleDocstring, + -1, + NULL, + NULL, + NULL, + NULL, + NULL}; +#define INITFUNC PyInit__api_implementation +#define INITFUNC_ERRORVAL NULL +#else +#define INITFUNC init_api_implementation +#define INITFUNC_ERRORVAL +#endif + +extern "C" { +PyMODINIT_FUNC INITFUNC() { +#if PY_MAJOR_VERSION >= 3 + PyObject* module = PyModule_Create(&_module); +#else + PyObject* module = Py_InitModule3(const_cast<char*>(kModuleName), NULL, + const_cast<char*>(kModuleDocstring)); +#endif + if (module == NULL) { + return INITFUNC_ERRORVAL; + } + + // Adds the module variable "api_version". + if (PyModule_AddIntConstant(module, const_cast<char*>(kImplVersionName), + kImplVersion)) +#if PY_MAJOR_VERSION < 3 + return; +#else + { + Py_DECREF(module); + return NULL; + } + + return module; +#endif +} +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/contrib/python/protobuf/py3/google/protobuf/internal/api_implementation.py b/contrib/python/protobuf/py3/google/protobuf/internal/api_implementation.py new file mode 100644 index 0000000000..be1af7df6b --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/internal/api_implementation.py @@ -0,0 +1,159 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Determine which implementation of the protobuf API is used in this process. +""" + +import os +import sys +import warnings + +try: + # pylint: disable=g-import-not-at-top + from google.protobuf.internal import _api_implementation + # The compile-time constants in the _api_implementation module can be used to + # switch to a certain implementation of the Python API at build time. + _api_version = _api_implementation.api_version + _proto_extension_modules_exist_in_build = True +except ImportError: + _api_version = -1 # Unspecified by compiler flags. + _proto_extension_modules_exist_in_build = False + +if _api_version == 1: + raise ValueError('api_version=1 is no longer supported.') +if _api_version < 0: # Still unspecified? + try: + # The presence of this module in a build allows the proto implementation to + # be upgraded merely via build deps rather than a compiler flag or the + # runtime environment variable. + # pylint: disable=g-import-not-at-top + from google.protobuf import _use_fast_cpp_protos + # Work around a known issue in the classic bootstrap .par import hook. + if not _use_fast_cpp_protos: + raise ImportError('_use_fast_cpp_protos import succeeded but was None') + del _use_fast_cpp_protos + _api_version = 2 + from google.protobuf import use_pure_python + raise RuntimeError( + 'Conflicting deps on both :use_fast_cpp_protos and :use_pure_python.\n' + ' go/build_deps_on_BOTH_use_fast_cpp_protos_AND_use_pure_python\n' + 'This should be impossible via a link error at build time...') + except ImportError: + try: + # pylint: disable=g-import-not-at-top + from google.protobuf import use_pure_python + del use_pure_python # Avoids a pylint error and namespace pollution. + _api_version = 0 + except ImportError: + # TODO(b/74017912): It's unsafe to enable :use_fast_cpp_protos by default; + # it can cause data loss if you have any Python-only extensions to any + # message passed back and forth with C++ code. + # + # TODO(b/17427486): Once that bug is fixed, we want to make both Python 2 + # and Python 3 default to `_api_version = 2` (C++ implementation V2). + pass + +_default_implementation_type = ( + 'python' if _api_version <= 0 else 'cpp') + +# This environment variable can be used to switch to a certain implementation +# of the Python API, overriding the compile-time constants in the +# _api_implementation module. Right now only 'python' and 'cpp' are valid +# values. Any other value will be ignored. +_implementation_type = os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', + _default_implementation_type) + +if _implementation_type != 'python': + _implementation_type = 'cpp' + +if 'PyPy' in sys.version and _implementation_type == 'cpp': + warnings.warn('PyPy does not work yet with cpp protocol buffers. ' + 'Falling back to the python implementation.') + _implementation_type = 'python' + +# This environment variable can be used to switch between the two +# 'cpp' implementations, overriding the compile-time constants in the +# _api_implementation module. Right now only '2' is supported. Any other +# value will cause an error to be raised. +_implementation_version_str = os.getenv( + 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', '2') + +if _implementation_version_str != '2': + raise ValueError( + 'unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: "' + + _implementation_version_str + '" (supported versions: 2)' + ) + +_implementation_version = int(_implementation_version_str) + + +# Detect if serialization should be deterministic by default +try: + # The presence of this module in a build allows the proto implementation to + # be upgraded merely via build deps. + # + # NOTE: Merely importing this automatically enables deterministic proto + # serialization for C++ code, but we still need to export it as a boolean so + # that we can do the same for `_implementation_type == 'python'`. + # + # NOTE2: It is possible for C++ code to enable deterministic serialization by + # default _without_ affecting Python code, if the C++ implementation is not in + # use by this module. That is intended behavior, so we don't actually expose + # this boolean outside of this module. + # + # pylint: disable=g-import-not-at-top,unused-import + from google.protobuf import enable_deterministic_proto_serialization + _python_deterministic_proto_serialization = True +except ImportError: + _python_deterministic_proto_serialization = False + + +# Usage of this function is discouraged. Clients shouldn't care which +# implementation of the API is in use. Note that there is no guarantee +# that differences between APIs will be maintained. +# Please don't use this function if possible. +def Type(): + return _implementation_type + + +def _SetType(implementation_type): + """Never use! Only for protobuf benchmark.""" + global _implementation_type + _implementation_type = implementation_type + + +# See comment on 'Type' above. +def Version(): + return _implementation_version + + +# For internal use only +def IsPythonDefaultSerializationDeterministic(): + return _python_deterministic_proto_serialization diff --git a/contrib/python/protobuf/py3/google/protobuf/internal/containers.py b/contrib/python/protobuf/py3/google/protobuf/internal/containers.py new file mode 100644 index 0000000000..92793490bb --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/internal/containers.py @@ -0,0 +1,785 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains container classes to represent different protocol buffer types. + +This file defines container classes which represent categories of protocol +buffer field types which need extra maintenance. Currently these categories +are: + +- Repeated scalar fields - These are all repeated fields which aren't + composite (e.g. they are of simple types like int32, string, etc). +- Repeated composite fields - Repeated fields which are composite. This + includes groups and nested messages. +""" + +__author__ = 'petar@google.com (Petar Petrov)' + +import sys +try: + # This fallback applies for all versions of Python before 3.3 + import collections.abc as collections_abc +except ImportError: + import collections as collections_abc + +if sys.version_info[0] < 3: + # We would use collections_abc.MutableMapping all the time, but in Python 2 + # it doesn't define __slots__. This causes two significant problems: + # + # 1. we can't disallow arbitrary attribute assignment, even if our derived + # classes *do* define __slots__. + # + # 2. we can't safely derive a C type from it without __slots__ defined (the + # interpreter expects to find a dict at tp_dictoffset, which we can't + # robustly provide. And we don't want an instance dict anyway. + # + # So this is the Python 2.7 definition of Mapping/MutableMapping functions + # verbatim, except that: + # 1. We declare __slots__. + # 2. We don't declare this as a virtual base class. The classes defined + # in collections_abc are the interesting base classes, not us. + # + # Note: deriving from object is critical. It is the only thing that makes + # this a true type, allowing us to derive from it in C++ cleanly and making + # __slots__ properly disallow arbitrary element assignment. + + class Mapping(object): + __slots__ = () + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def __contains__(self, key): + try: + self[key] + except KeyError: + return False + else: + return True + + def iterkeys(self): + return iter(self) + + def itervalues(self): + for key in self: + yield self[key] + + def iteritems(self): + for key in self: + yield (key, self[key]) + + def keys(self): + return list(self) + + def items(self): + return [(key, self[key]) for key in self] + + def values(self): + return [self[key] for key in self] + + # Mappings are not hashable by default, but subclasses can change this + __hash__ = None + + def __eq__(self, other): + if not isinstance(other, collections_abc.Mapping): + return NotImplemented + return dict(self.items()) == dict(other.items()) + + def __ne__(self, other): + return not (self == other) + + class MutableMapping(Mapping): + __slots__ = () + + __marker = object() + + def pop(self, key, default=__marker): + try: + value = self[key] + except KeyError: + if default is self.__marker: + raise + return default + else: + del self[key] + return value + + def popitem(self): + try: + key = next(iter(self)) + except StopIteration: + raise KeyError + value = self[key] + del self[key] + return key, value + + def clear(self): + try: + while True: + self.popitem() + except KeyError: + pass + + def update(*args, **kwds): + if len(args) > 2: + raise TypeError("update() takes at most 2 positional " + "arguments ({} given)".format(len(args))) + elif not args: + raise TypeError("update() takes at least 1 argument (0 given)") + self = args[0] + other = args[1] if len(args) >= 2 else () + + if isinstance(other, Mapping): + for key in other: + self[key] = other[key] + elif hasattr(other, "keys"): + for key in other.keys(): + self[key] = other[key] + else: + for key, value in other: + self[key] = value + for key, value in kwds.items(): + self[key] = value + + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + self[key] = default + return default + + collections_abc.Mapping.register(Mapping) + collections_abc.MutableMapping.register(MutableMapping) + +else: + # In Python 3 we can just use MutableMapping directly, because it defines + # __slots__. + MutableMapping = collections_abc.MutableMapping + + +class BaseContainer(object): + + """Base container class.""" + + # Minimizes memory usage and disallows assignment to other attributes. + __slots__ = ['_message_listener', '_values'] + + def __init__(self, message_listener): + """ + Args: + message_listener: A MessageListener implementation. + The RepeatedScalarFieldContainer will call this object's + Modified() method when it is modified. + """ + self._message_listener = message_listener + self._values = [] + + def __getitem__(self, key): + """Retrieves item by the specified key.""" + return self._values[key] + + def __len__(self): + """Returns the number of elements in the container.""" + return len(self._values) + + def __ne__(self, other): + """Checks if another instance isn't equal to this one.""" + # The concrete classes should define __eq__. + return not self == other + + def __hash__(self): + raise TypeError('unhashable object') + + def __repr__(self): + return repr(self._values) + + def sort(self, *args, **kwargs): + # Continue to support the old sort_function keyword argument. + # This is expected to be a rare occurrence, so use LBYL to avoid + # the overhead of actually catching KeyError. + if 'sort_function' in kwargs: + kwargs['cmp'] = kwargs.pop('sort_function') + self._values.sort(*args, **kwargs) + + def reverse(self): + self._values.reverse() + + +collections_abc.MutableSequence.register(BaseContainer) + + +class RepeatedScalarFieldContainer(BaseContainer): + """Simple, type-checked, list-like container for holding repeated scalars.""" + + # Disallows assignment to other attributes. + __slots__ = ['_type_checker'] + + def __init__(self, message_listener, type_checker): + """Args: + + message_listener: A MessageListener implementation. The + RepeatedScalarFieldContainer will call this object's Modified() method + when it is modified. + type_checker: A type_checkers.ValueChecker instance to run on elements + inserted into this container. + """ + super(RepeatedScalarFieldContainer, self).__init__(message_listener) + self._type_checker = type_checker + + def append(self, value): + """Appends an item to the list. Similar to list.append().""" + self._values.append(self._type_checker.CheckValue(value)) + if not self._message_listener.dirty: + self._message_listener.Modified() + + def insert(self, key, value): + """Inserts the item at the specified position. Similar to list.insert().""" + self._values.insert(key, self._type_checker.CheckValue(value)) + if not self._message_listener.dirty: + self._message_listener.Modified() + + def extend(self, elem_seq): + """Extends by appending the given iterable. Similar to list.extend().""" + + if elem_seq is None: + return + try: + elem_seq_iter = iter(elem_seq) + except TypeError: + if not elem_seq: + # silently ignore falsy inputs :-/. + # TODO(ptucker): Deprecate this behavior. b/18413862 + return + raise + + new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter] + if new_values: + self._values.extend(new_values) + self._message_listener.Modified() + + def MergeFrom(self, other): + """Appends the contents of another repeated field of the same type to this + one. We do not check the types of the individual fields. + """ + self._values.extend(other._values) + self._message_listener.Modified() + + def remove(self, elem): + """Removes an item from the list. Similar to list.remove().""" + self._values.remove(elem) + self._message_listener.Modified() + + def pop(self, key=-1): + """Removes and returns an item at a given index. Similar to list.pop().""" + value = self._values[key] + self.__delitem__(key) + return value + + def __setitem__(self, key, value): + """Sets the item on the specified position.""" + if isinstance(key, slice): # PY3 + if key.step is not None: + raise ValueError('Extended slices not supported') + self.__setslice__(key.start, key.stop, value) + else: + self._values[key] = self._type_checker.CheckValue(value) + self._message_listener.Modified() + + def __getslice__(self, start, stop): + """Retrieves the subset of items from between the specified indices.""" + return self._values[start:stop] + + def __setslice__(self, start, stop, values): + """Sets the subset of items from between the specified indices.""" + new_values = [] + for value in values: + new_values.append(self._type_checker.CheckValue(value)) + self._values[start:stop] = new_values + self._message_listener.Modified() + + def __delitem__(self, key): + """Deletes the item at the specified position.""" + del self._values[key] + self._message_listener.Modified() + + def __delslice__(self, start, stop): + """Deletes the subset of items from between the specified indices.""" + del self._values[start:stop] + self._message_listener.Modified() + + def __eq__(self, other): + """Compares the current instance with another one.""" + if self is other: + return True + # Special case for the same type which should be common and fast. + if isinstance(other, self.__class__): + return other._values == self._values + # We are presumably comparing against some other sequence type. + return other == self._values + + +class RepeatedCompositeFieldContainer(BaseContainer): + + """Simple, list-like container for holding repeated composite fields.""" + + # Disallows assignment to other attributes. + __slots__ = ['_message_descriptor'] + + def __init__(self, message_listener, message_descriptor): + """ + Note that we pass in a descriptor instead of the generated directly, + since at the time we construct a _RepeatedCompositeFieldContainer we + haven't yet necessarily initialized the type that will be contained in the + container. + + Args: + message_listener: A MessageListener implementation. + The RepeatedCompositeFieldContainer will call this object's + Modified() method when it is modified. + message_descriptor: A Descriptor instance describing the protocol type + that should be present in this container. We'll use the + _concrete_class field of this descriptor when the client calls add(). + """ + super(RepeatedCompositeFieldContainer, self).__init__(message_listener) + self._message_descriptor = message_descriptor + + def add(self, **kwargs): + """Adds a new element at the end of the list and returns it. Keyword + arguments may be used to initialize the element. + """ + new_element = self._message_descriptor._concrete_class(**kwargs) + new_element._SetListener(self._message_listener) + self._values.append(new_element) + if not self._message_listener.dirty: + self._message_listener.Modified() + return new_element + + def append(self, value): + """Appends one element by copying the message.""" + new_element = self._message_descriptor._concrete_class() + new_element._SetListener(self._message_listener) + new_element.CopyFrom(value) + self._values.append(new_element) + if not self._message_listener.dirty: + self._message_listener.Modified() + + def insert(self, key, value): + """Inserts the item at the specified position by copying.""" + new_element = self._message_descriptor._concrete_class() + new_element._SetListener(self._message_listener) + new_element.CopyFrom(value) + self._values.insert(key, new_element) + if not self._message_listener.dirty: + self._message_listener.Modified() + + def extend(self, elem_seq): + """Extends by appending the given sequence of elements of the same type + + as this one, copying each individual message. + """ + message_class = self._message_descriptor._concrete_class + listener = self._message_listener + values = self._values + for message in elem_seq: + new_element = message_class() + new_element._SetListener(listener) + new_element.MergeFrom(message) + values.append(new_element) + listener.Modified() + + def MergeFrom(self, other): + """Appends the contents of another repeated field of the same type to this + one, copying each individual message. + """ + self.extend(other._values) + + def remove(self, elem): + """Removes an item from the list. Similar to list.remove().""" + self._values.remove(elem) + self._message_listener.Modified() + + def pop(self, key=-1): + """Removes and returns an item at a given index. Similar to list.pop().""" + value = self._values[key] + self.__delitem__(key) + return value + + def __getslice__(self, start, stop): + """Retrieves the subset of items from between the specified indices.""" + return self._values[start:stop] + + def __delitem__(self, key): + """Deletes the item at the specified position.""" + del self._values[key] + self._message_listener.Modified() + + def __delslice__(self, start, stop): + """Deletes the subset of items from between the specified indices.""" + del self._values[start:stop] + self._message_listener.Modified() + + def __eq__(self, other): + """Compares the current instance with another one.""" + if self is other: + return True + if not isinstance(other, self.__class__): + raise TypeError('Can only compare repeated composite fields against ' + 'other repeated composite fields.') + return self._values == other._values + + +class ScalarMap(MutableMapping): + + """Simple, type-checked, dict-like container for holding repeated scalars.""" + + # Disallows assignment to other attributes. + __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener', + '_entry_descriptor'] + + def __init__(self, message_listener, key_checker, value_checker, + entry_descriptor): + """ + Args: + message_listener: A MessageListener implementation. + The ScalarMap will call this object's Modified() method when it + is modified. + key_checker: A type_checkers.ValueChecker instance to run on keys + inserted into this container. + value_checker: A type_checkers.ValueChecker instance to run on values + inserted into this container. + entry_descriptor: The MessageDescriptor of a map entry: key and value. + """ + self._message_listener = message_listener + self._key_checker = key_checker + self._value_checker = value_checker + self._entry_descriptor = entry_descriptor + self._values = {} + + def __getitem__(self, key): + try: + return self._values[key] + except KeyError: + key = self._key_checker.CheckValue(key) + val = self._value_checker.DefaultValue() + self._values[key] = val + return val + + def __contains__(self, item): + # We check the key's type to match the strong-typing flavor of the API. + # Also this makes it easier to match the behavior of the C++ implementation. + self._key_checker.CheckValue(item) + return item in self._values + + # We need to override this explicitly, because our defaultdict-like behavior + # will make the default implementation (from our base class) always insert + # the key. + def get(self, key, default=None): + if key in self: + return self[key] + else: + return default + + def __setitem__(self, key, value): + checked_key = self._key_checker.CheckValue(key) + checked_value = self._value_checker.CheckValue(value) + self._values[checked_key] = checked_value + self._message_listener.Modified() + + def __delitem__(self, key): + del self._values[key] + self._message_listener.Modified() + + def __len__(self): + return len(self._values) + + def __iter__(self): + return iter(self._values) + + def __repr__(self): + return repr(self._values) + + def MergeFrom(self, other): + self._values.update(other._values) + self._message_listener.Modified() + + def InvalidateIterators(self): + # It appears that the only way to reliably invalidate iterators to + # self._values is to ensure that its size changes. + original = self._values + self._values = original.copy() + original[None] = None + + # This is defined in the abstract base, but we can do it much more cheaply. + def clear(self): + self._values.clear() + self._message_listener.Modified() + + def GetEntryClass(self): + return self._entry_descriptor._concrete_class + + +class MessageMap(MutableMapping): + + """Simple, type-checked, dict-like container for with submessage values.""" + + # Disallows assignment to other attributes. + __slots__ = ['_key_checker', '_values', '_message_listener', + '_message_descriptor', '_entry_descriptor'] + + def __init__(self, message_listener, message_descriptor, key_checker, + entry_descriptor): + """ + Args: + message_listener: A MessageListener implementation. + The ScalarMap will call this object's Modified() method when it + is modified. + key_checker: A type_checkers.ValueChecker instance to run on keys + inserted into this container. + value_checker: A type_checkers.ValueChecker instance to run on values + inserted into this container. + entry_descriptor: The MessageDescriptor of a map entry: key and value. + """ + self._message_listener = message_listener + self._message_descriptor = message_descriptor + self._key_checker = key_checker + self._entry_descriptor = entry_descriptor + self._values = {} + + def __getitem__(self, key): + key = self._key_checker.CheckValue(key) + try: + return self._values[key] + except KeyError: + new_element = self._message_descriptor._concrete_class() + new_element._SetListener(self._message_listener) + self._values[key] = new_element + self._message_listener.Modified() + + return new_element + + def get_or_create(self, key): + """get_or_create() is an alias for getitem (ie. map[key]). + + Args: + key: The key to get or create in the map. + + This is useful in cases where you want to be explicit that the call is + mutating the map. This can avoid lint errors for statements like this + that otherwise would appear to be pointless statements: + + msg.my_map[key] + """ + return self[key] + + # We need to override this explicitly, because our defaultdict-like behavior + # will make the default implementation (from our base class) always insert + # the key. + def get(self, key, default=None): + if key in self: + return self[key] + else: + return default + + def __contains__(self, item): + item = self._key_checker.CheckValue(item) + return item in self._values + + def __setitem__(self, key, value): + raise ValueError('May not set values directly, call my_map[key].foo = 5') + + def __delitem__(self, key): + key = self._key_checker.CheckValue(key) + del self._values[key] + self._message_listener.Modified() + + def __len__(self): + return len(self._values) + + def __iter__(self): + return iter(self._values) + + def __repr__(self): + return repr(self._values) + + def MergeFrom(self, other): + # pylint: disable=protected-access + for key in other._values: + # According to documentation: "When parsing from the wire or when merging, + # if there are duplicate map keys the last key seen is used". + if key in self: + del self[key] + self[key].CopyFrom(other[key]) + # self._message_listener.Modified() not required here, because + # mutations to submessages already propagate. + + def InvalidateIterators(self): + # It appears that the only way to reliably invalidate iterators to + # self._values is to ensure that its size changes. + original = self._values + self._values = original.copy() + original[None] = None + + # This is defined in the abstract base, but we can do it much more cheaply. + def clear(self): + self._values.clear() + self._message_listener.Modified() + + def GetEntryClass(self): + return self._entry_descriptor._concrete_class + + +class _UnknownField(object): + + """A parsed unknown field.""" + + # Disallows assignment to other attributes. + __slots__ = ['_field_number', '_wire_type', '_data'] + + def __init__(self, field_number, wire_type, data): + self._field_number = field_number + self._wire_type = wire_type + self._data = data + return + + def __lt__(self, other): + # pylint: disable=protected-access + return self._field_number < other._field_number + + def __eq__(self, other): + if self is other: + return True + # pylint: disable=protected-access + return (self._field_number == other._field_number and + self._wire_type == other._wire_type and + self._data == other._data) + + +class UnknownFieldRef(object): + + def __init__(self, parent, index): + self._parent = parent + self._index = index + return + + def _check_valid(self): + if not self._parent: + raise ValueError('UnknownField does not exist. ' + 'The parent message might be cleared.') + if self._index >= len(self._parent): + raise ValueError('UnknownField does not exist. ' + 'The parent message might be cleared.') + + @property + def field_number(self): + self._check_valid() + # pylint: disable=protected-access + return self._parent._internal_get(self._index)._field_number + + @property + def wire_type(self): + self._check_valid() + # pylint: disable=protected-access + return self._parent._internal_get(self._index)._wire_type + + @property + def data(self): + self._check_valid() + # pylint: disable=protected-access + return self._parent._internal_get(self._index)._data + + +class UnknownFieldSet(object): + + """UnknownField container""" + + # Disallows assignment to other attributes. + __slots__ = ['_values'] + + def __init__(self): + self._values = [] + + def __getitem__(self, index): + if self._values is None: + raise ValueError('UnknownFields does not exist. ' + 'The parent message might be cleared.') + size = len(self._values) + if index < 0: + index += size + if index < 0 or index >= size: + raise IndexError('index %d out of range'.index) + + return UnknownFieldRef(self, index) + + def _internal_get(self, index): + return self._values[index] + + def __len__(self): + if self._values is None: + raise ValueError('UnknownFields does not exist. ' + 'The parent message might be cleared.') + return len(self._values) + + def _add(self, field_number, wire_type, data): + unknown_field = _UnknownField(field_number, wire_type, data) + self._values.append(unknown_field) + return unknown_field + + def __iter__(self): + for i in range(len(self)): + yield UnknownFieldRef(self, i) + + def _extend(self, other): + if other is None: + return + # pylint: disable=protected-access + self._values.extend(other._values) + + def __eq__(self, other): + if self is other: + return True + # Sort unknown fields because their order shouldn't + # affect equality test. + values = list(self._values) + if other is None: + return not values + values.sort() + # pylint: disable=protected-access + other_values = sorted(other._values) + return values == other_values + + def _clear(self): + for value in self._values: + # pylint: disable=protected-access + if isinstance(value._data, UnknownFieldSet): + value._data._clear() # pylint: disable=protected-access + self._values = None diff --git a/contrib/python/protobuf/py3/google/protobuf/internal/decoder.py b/contrib/python/protobuf/py3/google/protobuf/internal/decoder.py new file mode 100644 index 0000000000..6804986b6e --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/internal/decoder.py @@ -0,0 +1,1057 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Code for decoding protocol buffer primitives. + +This code is very similar to encoder.py -- read the docs for that module first. + +A "decoder" is a function with the signature: + Decode(buffer, pos, end, message, field_dict) +The arguments are: + buffer: The string containing the encoded message. + pos: The current position in the string. + end: The position in the string where the current message ends. May be + less than len(buffer) if we're reading a sub-message. + message: The message object into which we're parsing. + field_dict: message._fields (avoids a hashtable lookup). +The decoder reads the field and stores it into field_dict, returning the new +buffer position. A decoder for a repeated field may proactively decode all of +the elements of that field, if they appear consecutively. + +Note that decoders may throw any of the following: + IndexError: Indicates a truncated message. + struct.error: Unpacking of a fixed-width field failed. + message.DecodeError: Other errors. + +Decoders are expected to raise an exception if they are called with pos > end. +This allows callers to be lax about bounds checking: it's fineto read past +"end" as long as you are sure that someone else will notice and throw an +exception later on. + +Something up the call stack is expected to catch IndexError and struct.error +and convert them to message.DecodeError. + +Decoders are constructed using decoder constructors with the signature: + MakeDecoder(field_number, is_repeated, is_packed, key, new_default) +The arguments are: + field_number: The field number of the field we want to decode. + is_repeated: Is the field a repeated field? (bool) + is_packed: Is the field a packed field? (bool) + key: The key to use when looking up the field within field_dict. + (This is actually the FieldDescriptor but nothing in this + file should depend on that.) + new_default: A function which takes a message object as a parameter and + returns a new instance of the default value for this field. + (This is called for repeated fields and sub-messages, when an + instance does not already exist.) + +As with encoders, we define a decoder constructor for every type of field. +Then, for every field of every message class we construct an actual decoder. +That decoder goes into a dict indexed by tag, so when we decode a message +we repeatedly read a tag, look up the corresponding decoder, and invoke it. +""" + +__author__ = 'kenton@google.com (Kenton Varda)' + +import struct +import sys +import six + +_UCS2_MAXUNICODE = 65535 +if six.PY3: + long = int +else: + import re # pylint: disable=g-import-not-at-top + _SURROGATE_PATTERN = re.compile(six.u(r'[\ud800-\udfff]')) + +from google.protobuf.internal import containers +from google.protobuf.internal import encoder +from google.protobuf.internal import wire_format +from google.protobuf import message + + +# This will overflow and thus become IEEE-754 "infinity". We would use +# "float('inf')" but it doesn't work on Windows pre-Python-2.6. +_POS_INF = 1e10000 +_NEG_INF = -_POS_INF +_NAN = _POS_INF * 0 + + +# This is not for optimization, but rather to avoid conflicts with local +# variables named "message". +_DecodeError = message.DecodeError + + +def _VarintDecoder(mask, result_type): + """Return an encoder for a basic varint value (does not include tag). + + Decoded values will be bitwise-anded with the given mask before being + returned, e.g. to limit them to 32 bits. The returned decoder does not + take the usual "end" parameter -- the caller is expected to do bounds checking + after the fact (often the caller can defer such checking until later). The + decoder returns a (value, new_pos) pair. + """ + + def DecodeVarint(buffer, pos): + result = 0 + shift = 0 + while 1: + b = six.indexbytes(buffer, pos) + result |= ((b & 0x7f) << shift) + pos += 1 + if not (b & 0x80): + result &= mask + result = result_type(result) + return (result, pos) + shift += 7 + if shift >= 64: + raise _DecodeError('Too many bytes when decoding varint.') + return DecodeVarint + + +def _SignedVarintDecoder(bits, result_type): + """Like _VarintDecoder() but decodes signed values.""" + + signbit = 1 << (bits - 1) + mask = (1 << bits) - 1 + + def DecodeVarint(buffer, pos): + result = 0 + shift = 0 + while 1: + b = six.indexbytes(buffer, pos) + result |= ((b & 0x7f) << shift) + pos += 1 + if not (b & 0x80): + result &= mask + result = (result ^ signbit) - signbit + result = result_type(result) + return (result, pos) + shift += 7 + if shift >= 64: + raise _DecodeError('Too many bytes when decoding varint.') + return DecodeVarint + +# We force 32-bit values to int and 64-bit values to long to make +# alternate implementations where the distinction is more significant +# (e.g. the C++ implementation) simpler. + +_DecodeVarint = _VarintDecoder((1 << 64) - 1, long) +_DecodeSignedVarint = _SignedVarintDecoder(64, long) + +# Use these versions for values which must be limited to 32 bits. +_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int) +_DecodeSignedVarint32 = _SignedVarintDecoder(32, int) + + +def ReadTag(buffer, pos): + """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple. + + We return the raw bytes of the tag rather than decoding them. The raw + bytes can then be used to look up the proper decoder. This effectively allows + us to trade some work that would be done in pure-python (decoding a varint) + for work that is done in C (searching for a byte string in a hash table). + In a low-level language it would be much cheaper to decode the varint and + use that, but not in Python. + + Args: + buffer: memoryview object of the encoded bytes + pos: int of the current position to start from + + Returns: + Tuple[bytes, int] of the tag data and new position. + """ + start = pos + while six.indexbytes(buffer, pos) & 0x80: + pos += 1 + pos += 1 + + tag_bytes = buffer[start:pos].tobytes() + return tag_bytes, pos + + +# -------------------------------------------------------------------- + + +def _SimpleDecoder(wire_type, decode_value): + """Return a constructor for a decoder for fields of a particular type. + + Args: + wire_type: The field's wire type. + decode_value: A function which decodes an individual value, e.g. + _DecodeVarint() + """ + + def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default, + clear_if_default=False): + if is_packed: + local_DecodeVarint = _DecodeVarint + def DecodePackedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + (endpoint, pos) = local_DecodeVarint(buffer, pos) + endpoint += pos + if endpoint > end: + raise _DecodeError('Truncated message.') + while pos < endpoint: + (element, pos) = decode_value(buffer, pos) + value.append(element) + if pos > endpoint: + del value[-1] # Discard corrupt value. + raise _DecodeError('Packed element was truncated.') + return pos + return DecodePackedField + elif is_repeated: + tag_bytes = encoder.TagBytes(field_number, wire_type) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + (element, new_pos) = decode_value(buffer, pos) + value.append(element) + # Predict that the next tag is another copy of the same repeated + # field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos >= end: + # Prediction failed. Return. + if new_pos > end: + raise _DecodeError('Truncated message.') + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + (new_value, pos) = decode_value(buffer, pos) + if pos > end: + raise _DecodeError('Truncated message.') + if clear_if_default and not new_value: + field_dict.pop(key, None) + else: + field_dict[key] = new_value + return pos + return DecodeField + + return SpecificDecoder + + +def _ModifiedDecoder(wire_type, decode_value, modify_value): + """Like SimpleDecoder but additionally invokes modify_value on every value + before storing it. Usually modify_value is ZigZagDecode. + """ + + # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but + # not enough to make a significant difference. + + def InnerDecode(buffer, pos): + (result, new_pos) = decode_value(buffer, pos) + return (modify_value(result), new_pos) + return _SimpleDecoder(wire_type, InnerDecode) + + +def _StructPackDecoder(wire_type, format): + """Return a constructor for a decoder for a fixed-width field. + + Args: + wire_type: The field's wire type. + format: The format string to pass to struct.unpack(). + """ + + value_size = struct.calcsize(format) + local_unpack = struct.unpack + + # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but + # not enough to make a significant difference. + + # Note that we expect someone up-stack to catch struct.error and convert + # it to _DecodeError -- this way we don't have to set up exception- + # handling blocks every time we parse one value. + + def InnerDecode(buffer, pos): + new_pos = pos + value_size + result = local_unpack(format, buffer[pos:new_pos])[0] + return (result, new_pos) + return _SimpleDecoder(wire_type, InnerDecode) + + +def _FloatDecoder(): + """Returns a decoder for a float field. + + This code works around a bug in struct.unpack for non-finite 32-bit + floating-point values. + """ + + local_unpack = struct.unpack + + def InnerDecode(buffer, pos): + """Decode serialized float to a float and new position. + + Args: + buffer: memoryview of the serialized bytes + pos: int, position in the memory view to start at. + + Returns: + Tuple[float, int] of the deserialized float value and new position + in the serialized data. + """ + # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign + # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand. + new_pos = pos + 4 + float_bytes = buffer[pos:new_pos].tobytes() + + # If this value has all its exponent bits set, then it's non-finite. + # In Python 2.4, struct.unpack will convert it to a finite 64-bit value. + # To avoid that, we parse it specially. + if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'): + # If at least one significand bit is set... + if float_bytes[0:3] != b'\x00\x00\x80': + return (_NAN, new_pos) + # If sign bit is set... + if float_bytes[3:4] == b'\xFF': + return (_NEG_INF, new_pos) + return (_POS_INF, new_pos) + + # Note that we expect someone up-stack to catch struct.error and convert + # it to _DecodeError -- this way we don't have to set up exception- + # handling blocks every time we parse one value. + result = local_unpack('<f', float_bytes)[0] + return (result, new_pos) + return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode) + + +def _DoubleDecoder(): + """Returns a decoder for a double field. + + This code works around a bug in struct.unpack for not-a-number. + """ + + local_unpack = struct.unpack + + def InnerDecode(buffer, pos): + """Decode serialized double to a double and new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + + Returns: + Tuple[float, int] of the decoded double value and new position + in the serialized data. + """ + # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign + # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand. + new_pos = pos + 8 + double_bytes = buffer[pos:new_pos].tobytes() + + # If this value has all its exponent bits set and at least one significand + # bit set, it's not a number. In Python 2.4, struct.unpack will treat it + # as inf or -inf. To avoid that, we treat it specially. + if ((double_bytes[7:8] in b'\x7F\xFF') + and (double_bytes[6:7] >= b'\xF0') + and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')): + return (_NAN, new_pos) + + # Note that we expect someone up-stack to catch struct.error and convert + # it to _DecodeError -- this way we don't have to set up exception- + # handling blocks every time we parse one value. + result = local_unpack('<d', double_bytes)[0] + return (result, new_pos) + return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode) + + +def EnumDecoder(field_number, is_repeated, is_packed, key, new_default, + clear_if_default=False): + """Returns a decoder for enum field.""" + enum_type = key.enum_type + if is_packed: + local_DecodeVarint = _DecodeVarint + def DecodePackedField(buffer, pos, end, message, field_dict): + """Decode serialized packed enum to its value and a new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + end: int, end position of serialized data + message: Message object to store unknown fields in + field_dict: Map[Descriptor, Any] to store decoded values in. + + Returns: + int, new position in serialized data. + """ + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + (endpoint, pos) = local_DecodeVarint(buffer, pos) + endpoint += pos + if endpoint > end: + raise _DecodeError('Truncated message.') + while pos < endpoint: + value_start_pos = pos + (element, pos) = _DecodeSignedVarint32(buffer, pos) + # pylint: disable=protected-access + if element in enum_type.values_by_number: + value.append(element) + else: + if not message._unknown_fields: + message._unknown_fields = [] + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_VARINT) + + message._unknown_fields.append( + (tag_bytes, buffer[value_start_pos:pos].tobytes())) + if message._unknown_field_set is None: + message._unknown_field_set = containers.UnknownFieldSet() + message._unknown_field_set._add( + field_number, wire_format.WIRETYPE_VARINT, element) + # pylint: enable=protected-access + if pos > endpoint: + if element in enum_type.values_by_number: + del value[-1] # Discard corrupt value. + else: + del message._unknown_fields[-1] + # pylint: disable=protected-access + del message._unknown_field_set._values[-1] + # pylint: enable=protected-access + raise _DecodeError('Packed element was truncated.') + return pos + return DecodePackedField + elif is_repeated: + tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + """Decode serialized repeated enum to its value and a new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + end: int, end position of serialized data + message: Message object to store unknown fields in + field_dict: Map[Descriptor, Any] to store decoded values in. + + Returns: + int, new position in serialized data. + """ + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + (element, new_pos) = _DecodeSignedVarint32(buffer, pos) + # pylint: disable=protected-access + if element in enum_type.values_by_number: + value.append(element) + else: + if not message._unknown_fields: + message._unknown_fields = [] + message._unknown_fields.append( + (tag_bytes, buffer[pos:new_pos].tobytes())) + if message._unknown_field_set is None: + message._unknown_field_set = containers.UnknownFieldSet() + message._unknown_field_set._add( + field_number, wire_format.WIRETYPE_VARINT, element) + # pylint: enable=protected-access + # Predict that the next tag is another copy of the same repeated + # field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos >= end: + # Prediction failed. Return. + if new_pos > end: + raise _DecodeError('Truncated message.') + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + """Decode serialized repeated enum to its value and a new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + end: int, end position of serialized data + message: Message object to store unknown fields in + field_dict: Map[Descriptor, Any] to store decoded values in. + + Returns: + int, new position in serialized data. + """ + value_start_pos = pos + (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) + if pos > end: + raise _DecodeError('Truncated message.') + if clear_if_default and not enum_value: + field_dict.pop(key, None) + return pos + # pylint: disable=protected-access + if enum_value in enum_type.values_by_number: + field_dict[key] = enum_value + else: + if not message._unknown_fields: + message._unknown_fields = [] + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_VARINT) + message._unknown_fields.append( + (tag_bytes, buffer[value_start_pos:pos].tobytes())) + if message._unknown_field_set is None: + message._unknown_field_set = containers.UnknownFieldSet() + message._unknown_field_set._add( + field_number, wire_format.WIRETYPE_VARINT, enum_value) + # pylint: enable=protected-access + return pos + return DecodeField + + +# -------------------------------------------------------------------- + + +Int32Decoder = _SimpleDecoder( + wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32) + +Int64Decoder = _SimpleDecoder( + wire_format.WIRETYPE_VARINT, _DecodeSignedVarint) + +UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32) +UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint) + +SInt32Decoder = _ModifiedDecoder( + wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode) +SInt64Decoder = _ModifiedDecoder( + wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode) + +# Note that Python conveniently guarantees that when using the '<' prefix on +# formats, they will also have the same size across all platforms (as opposed +# to without the prefix, where their sizes depend on the C compiler's basic +# type sizes). +Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I') +Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q') +SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i') +SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q') +FloatDecoder = _FloatDecoder() +DoubleDecoder = _DoubleDecoder() + +BoolDecoder = _ModifiedDecoder( + wire_format.WIRETYPE_VARINT, _DecodeVarint, bool) + + +def StringDecoder(field_number, is_repeated, is_packed, key, new_default, + is_strict_utf8=False, clear_if_default=False): + """Returns a decoder for a string field.""" + + local_DecodeVarint = _DecodeVarint + local_unicode = six.text_type + + def _ConvertToUnicode(memview): + """Convert byte to unicode.""" + byte_str = memview.tobytes() + try: + value = local_unicode(byte_str, 'utf-8') + except UnicodeDecodeError as e: + # add more information to the error message and re-raise it. + e.reason = '%s in field: %s' % (e, key.full_name) + raise + + if is_strict_utf8 and six.PY2 and sys.maxunicode > _UCS2_MAXUNICODE: + # Only do the check for python2 ucs4 when is_strict_utf8 enabled + if _SURROGATE_PATTERN.search(value): + reason = ('String field %s contains invalid UTF-8 data when parsing' + 'a protocol buffer: surrogates not allowed. Use' + 'the bytes type if you intend to send raw bytes.') % ( + key.full_name) + raise message.DecodeError(reason) + + return value + + assert not is_packed + if is_repeated: + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_LENGTH_DELIMITED) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated string.') + value.append(_ConvertToUnicode(buffer[pos:new_pos])) + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated string.') + if clear_if_default and not size: + field_dict.pop(key, None) + else: + field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos]) + return new_pos + return DecodeField + + +def BytesDecoder(field_number, is_repeated, is_packed, key, new_default, + clear_if_default=False): + """Returns a decoder for a bytes field.""" + + local_DecodeVarint = _DecodeVarint + + assert not is_packed + if is_repeated: + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_LENGTH_DELIMITED) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated string.') + value.append(buffer[pos:new_pos].tobytes()) + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated string.') + if clear_if_default and not size: + field_dict.pop(key, None) + else: + field_dict[key] = buffer[pos:new_pos].tobytes() + return new_pos + return DecodeField + + +def GroupDecoder(field_number, is_repeated, is_packed, key, new_default): + """Returns a decoder for a group field.""" + + end_tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_END_GROUP) + end_tag_len = len(end_tag_bytes) + + assert not is_packed + if is_repeated: + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_START_GROUP) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + # Read sub-message. + pos = value.add()._InternalParse(buffer, pos, end) + # Read end tag. + new_pos = pos+end_tag_len + if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: + raise _DecodeError('Missing group end tag.') + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + # Read sub-message. + pos = value._InternalParse(buffer, pos, end) + # Read end tag. + new_pos = pos+end_tag_len + if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: + raise _DecodeError('Missing group end tag.') + return new_pos + return DecodeField + + +def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): + """Returns a decoder for a message field.""" + + local_DecodeVarint = _DecodeVarint + + assert not is_packed + if is_repeated: + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_LENGTH_DELIMITED) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + # Read length. + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated message.') + # Read sub-message. + if value.add()._InternalParse(buffer, pos, new_pos) != new_pos: + # The only reason _InternalParse would return early is if it + # encountered an end-group tag. + raise _DecodeError('Unexpected end-group tag.') + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + # Read length. + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated message.') + # Read sub-message. + if value._InternalParse(buffer, pos, new_pos) != new_pos: + # The only reason _InternalParse would return early is if it encountered + # an end-group tag. + raise _DecodeError('Unexpected end-group tag.') + return new_pos + return DecodeField + + +# -------------------------------------------------------------------- + +MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP) + +def MessageSetItemDecoder(descriptor): + """Returns a decoder for a MessageSet item. + + The parameter is the message Descriptor. + + The message set message looks like this: + message MessageSet { + repeated group Item = 1 { + required int32 type_id = 2; + required string message = 3; + } + } + """ + + type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT) + message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED) + item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP) + + local_ReadTag = ReadTag + local_DecodeVarint = _DecodeVarint + local_SkipField = SkipField + + def DecodeItem(buffer, pos, end, message, field_dict): + """Decode serialized message set to its value and new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + end: int, end position of serialized data + message: Message object to store unknown fields in + field_dict: Map[Descriptor, Any] to store decoded values in. + + Returns: + int, new position in serialized data. + """ + message_set_item_start = pos + type_id = -1 + message_start = -1 + message_end = -1 + + # Technically, type_id and message can appear in any order, so we need + # a little loop here. + while 1: + (tag_bytes, pos) = local_ReadTag(buffer, pos) + if tag_bytes == type_id_tag_bytes: + (type_id, pos) = local_DecodeVarint(buffer, pos) + elif tag_bytes == message_tag_bytes: + (size, message_start) = local_DecodeVarint(buffer, pos) + pos = message_end = message_start + size + elif tag_bytes == item_end_tag_bytes: + break + else: + pos = SkipField(buffer, pos, end, tag_bytes) + if pos == -1: + raise _DecodeError('Missing group end tag.') + + if pos > end: + raise _DecodeError('Truncated message.') + + if type_id == -1: + raise _DecodeError('MessageSet item missing type_id.') + if message_start == -1: + raise _DecodeError('MessageSet item missing message.') + + extension = message.Extensions._FindExtensionByNumber(type_id) + # pylint: disable=protected-access + if extension is not None: + value = field_dict.get(extension) + if value is None: + message_type = extension.message_type + if not hasattr(message_type, '_concrete_class'): + # pylint: disable=protected-access + message._FACTORY.GetPrototype(message_type) + value = field_dict.setdefault( + extension, message_type._concrete_class()) + if value._InternalParse(buffer, message_start,message_end) != message_end: + # The only reason _InternalParse would return early is if it encountered + # an end-group tag. + raise _DecodeError('Unexpected end-group tag.') + else: + if not message._unknown_fields: + message._unknown_fields = [] + message._unknown_fields.append( + (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes())) + if message._unknown_field_set is None: + message._unknown_field_set = containers.UnknownFieldSet() + message._unknown_field_set._add( + type_id, + wire_format.WIRETYPE_LENGTH_DELIMITED, + buffer[message_start:message_end].tobytes()) + # pylint: enable=protected-access + + return pos + + return DecodeItem + +# -------------------------------------------------------------------- + +def MapDecoder(field_descriptor, new_default, is_message_map): + """Returns a decoder for a map field.""" + + key = field_descriptor + tag_bytes = encoder.TagBytes(field_descriptor.number, + wire_format.WIRETYPE_LENGTH_DELIMITED) + tag_len = len(tag_bytes) + local_DecodeVarint = _DecodeVarint + # Can't read _concrete_class yet; might not be initialized. + message_type = field_descriptor.message_type + + def DecodeMap(buffer, pos, end, message, field_dict): + submsg = message_type._concrete_class() + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + # Read length. + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated message.') + # Read sub-message. + submsg.Clear() + if submsg._InternalParse(buffer, pos, new_pos) != new_pos: + # The only reason _InternalParse would return early is if it + # encountered an end-group tag. + raise _DecodeError('Unexpected end-group tag.') + + if is_message_map: + value[submsg.key].CopyFrom(submsg.value) + else: + value[submsg.key] = submsg.value + + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. + return new_pos + + return DecodeMap + +# -------------------------------------------------------------------- +# Optimization is not as heavy here because calls to SkipField() are rare, +# except for handling end-group tags. + +def _SkipVarint(buffer, pos, end): + """Skip a varint value. Returns the new position.""" + # Previously ord(buffer[pos]) raised IndexError when pos is out of range. + # With this code, ord(b'') raises TypeError. Both are handled in + # python_message.py to generate a 'Truncated message' error. + while ord(buffer[pos:pos+1].tobytes()) & 0x80: + pos += 1 + pos += 1 + if pos > end: + raise _DecodeError('Truncated message.') + return pos + +def _SkipFixed64(buffer, pos, end): + """Skip a fixed64 value. Returns the new position.""" + + pos += 8 + if pos > end: + raise _DecodeError('Truncated message.') + return pos + + +def _DecodeFixed64(buffer, pos): + """Decode a fixed64.""" + new_pos = pos + 8 + return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos) + + +def _SkipLengthDelimited(buffer, pos, end): + """Skip a length-delimited value. Returns the new position.""" + + (size, pos) = _DecodeVarint(buffer, pos) + pos += size + if pos > end: + raise _DecodeError('Truncated message.') + return pos + + +def _SkipGroup(buffer, pos, end): + """Skip sub-group. Returns the new position.""" + + while 1: + (tag_bytes, pos) = ReadTag(buffer, pos) + new_pos = SkipField(buffer, pos, end, tag_bytes) + if new_pos == -1: + return pos + pos = new_pos + + +def _DecodeUnknownFieldSet(buffer, pos, end_pos=None): + """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position.""" + + unknown_field_set = containers.UnknownFieldSet() + while end_pos is None or pos < end_pos: + (tag_bytes, pos) = ReadTag(buffer, pos) + (tag, _) = _DecodeVarint(tag_bytes, 0) + field_number, wire_type = wire_format.UnpackTag(tag) + if wire_type == wire_format.WIRETYPE_END_GROUP: + break + (data, pos) = _DecodeUnknownField(buffer, pos, wire_type) + # pylint: disable=protected-access + unknown_field_set._add(field_number, wire_type, data) + + return (unknown_field_set, pos) + + +def _DecodeUnknownField(buffer, pos, wire_type): + """Decode a unknown field. Returns the UnknownField and new position.""" + + if wire_type == wire_format.WIRETYPE_VARINT: + (data, pos) = _DecodeVarint(buffer, pos) + elif wire_type == wire_format.WIRETYPE_FIXED64: + (data, pos) = _DecodeFixed64(buffer, pos) + elif wire_type == wire_format.WIRETYPE_FIXED32: + (data, pos) = _DecodeFixed32(buffer, pos) + elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: + (size, pos) = _DecodeVarint(buffer, pos) + data = buffer[pos:pos+size].tobytes() + pos += size + elif wire_type == wire_format.WIRETYPE_START_GROUP: + (data, pos) = _DecodeUnknownFieldSet(buffer, pos) + elif wire_type == wire_format.WIRETYPE_END_GROUP: + return (0, -1) + else: + raise _DecodeError('Wrong wire type in tag.') + + return (data, pos) + + +def _EndGroup(buffer, pos, end): + """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.""" + + return -1 + + +def _SkipFixed32(buffer, pos, end): + """Skip a fixed32 value. Returns the new position.""" + + pos += 4 + if pos > end: + raise _DecodeError('Truncated message.') + return pos + + +def _DecodeFixed32(buffer, pos): + """Decode a fixed32.""" + + new_pos = pos + 4 + return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos) + + +def _RaiseInvalidWireType(buffer, pos, end): + """Skip function for unknown wire types. Raises an exception.""" + + raise _DecodeError('Tag had invalid wire type.') + +def _FieldSkipper(): + """Constructs the SkipField function.""" + + WIRETYPE_TO_SKIPPER = [ + _SkipVarint, + _SkipFixed64, + _SkipLengthDelimited, + _SkipGroup, + _EndGroup, + _SkipFixed32, + _RaiseInvalidWireType, + _RaiseInvalidWireType, + ] + + wiretype_mask = wire_format.TAG_TYPE_MASK + + def SkipField(buffer, pos, end, tag_bytes): + """Skips a field with the specified tag. + + |pos| should point to the byte immediately after the tag. + + Returns: + The new position (after the tag value), or -1 if the tag is an end-group + tag (in which case the calling loop should break). + """ + + # The wire type is always in the first byte since varints are little-endian. + wire_type = ord(tag_bytes[0:1]) & wiretype_mask + return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) + + return SkipField + +SkipField = _FieldSkipper() diff --git a/contrib/python/protobuf/py3/google/protobuf/internal/encoder.py b/contrib/python/protobuf/py3/google/protobuf/internal/encoder.py new file mode 100644 index 0000000000..0c016f3cfa --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/internal/encoder.py @@ -0,0 +1,830 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Code for encoding protocol message primitives. + +Contains the logic for encoding every logical protocol field type +into one of the 5 physical wire types. + +This code is designed to push the Python interpreter's performance to the +limits. + +The basic idea is that at startup time, for every field (i.e. every +FieldDescriptor) we construct two functions: a "sizer" and an "encoder". The +sizer takes a value of this field's type and computes its byte size. The +encoder takes a writer function and a value. It encodes the value into byte +strings and invokes the writer function to write those strings. Typically the +writer function is the write() method of a BytesIO. + +We try to do as much work as possible when constructing the writer and the +sizer rather than when calling them. In particular: +* We copy any needed global functions to local variables, so that we do not need + to do costly global table lookups at runtime. +* Similarly, we try to do any attribute lookups at startup time if possible. +* Every field's tag is encoded to bytes at startup, since it can't change at + runtime. +* Whatever component of the field size we can compute at startup, we do. +* We *avoid* sharing code if doing so would make the code slower and not sharing + does not burden us too much. For example, encoders for repeated fields do + not just call the encoders for singular fields in a loop because this would + add an extra function call overhead for every loop iteration; instead, we + manually inline the single-value encoder into the loop. +* If a Python function lacks a return statement, Python actually generates + instructions to pop the result of the last statement off the stack, push + None onto the stack, and then return that. If we really don't care what + value is returned, then we can save two instructions by returning the + result of the last statement. It looks funny but it helps. +* We assume that type and bounds checking has happened at a higher level. +""" + +__author__ = 'kenton@google.com (Kenton Varda)' + +import struct + +import six + +from google.protobuf.internal import wire_format + + +# This will overflow and thus become IEEE-754 "infinity". We would use +# "float('inf')" but it doesn't work on Windows pre-Python-2.6. +_POS_INF = 1e10000 +_NEG_INF = -_POS_INF + + +def _VarintSize(value): + """Compute the size of a varint value.""" + if value <= 0x7f: return 1 + if value <= 0x3fff: return 2 + if value <= 0x1fffff: return 3 + if value <= 0xfffffff: return 4 + if value <= 0x7ffffffff: return 5 + if value <= 0x3ffffffffff: return 6 + if value <= 0x1ffffffffffff: return 7 + if value <= 0xffffffffffffff: return 8 + if value <= 0x7fffffffffffffff: return 9 + return 10 + + +def _SignedVarintSize(value): + """Compute the size of a signed varint value.""" + if value < 0: return 10 + if value <= 0x7f: return 1 + if value <= 0x3fff: return 2 + if value <= 0x1fffff: return 3 + if value <= 0xfffffff: return 4 + if value <= 0x7ffffffff: return 5 + if value <= 0x3ffffffffff: return 6 + if value <= 0x1ffffffffffff: return 7 + if value <= 0xffffffffffffff: return 8 + if value <= 0x7fffffffffffffff: return 9 + return 10 + + +def _TagSize(field_number): + """Returns the number of bytes required to serialize a tag with this field + number.""" + # Just pass in type 0, since the type won't affect the tag+type size. + return _VarintSize(wire_format.PackTag(field_number, 0)) + + +# -------------------------------------------------------------------- +# In this section we define some generic sizers. Each of these functions +# takes parameters specific to a particular field type, e.g. int32 or fixed64. +# It returns another function which in turn takes parameters specific to a +# particular field, e.g. the field number and whether it is repeated or packed. +# Look at the next section to see how these are used. + + +def _SimpleSizer(compute_value_size): + """A sizer which uses the function compute_value_size to compute the size of + each value. Typically compute_value_size is _VarintSize.""" + + def SpecificSizer(field_number, is_repeated, is_packed): + tag_size = _TagSize(field_number) + if is_packed: + local_VarintSize = _VarintSize + def PackedFieldSize(value): + result = 0 + for element in value: + result += compute_value_size(element) + return result + local_VarintSize(result) + tag_size + return PackedFieldSize + elif is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + result += compute_value_size(element) + return result + return RepeatedFieldSize + else: + def FieldSize(value): + return tag_size + compute_value_size(value) + return FieldSize + + return SpecificSizer + + +def _ModifiedSizer(compute_value_size, modify_value): + """Like SimpleSizer, but modify_value is invoked on each value before it is + passed to compute_value_size. modify_value is typically ZigZagEncode.""" + + def SpecificSizer(field_number, is_repeated, is_packed): + tag_size = _TagSize(field_number) + if is_packed: + local_VarintSize = _VarintSize + def PackedFieldSize(value): + result = 0 + for element in value: + result += compute_value_size(modify_value(element)) + return result + local_VarintSize(result) + tag_size + return PackedFieldSize + elif is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + result += compute_value_size(modify_value(element)) + return result + return RepeatedFieldSize + else: + def FieldSize(value): + return tag_size + compute_value_size(modify_value(value)) + return FieldSize + + return SpecificSizer + + +def _FixedSizer(value_size): + """Like _SimpleSizer except for a fixed-size field. The input is the size + of one value.""" + + def SpecificSizer(field_number, is_repeated, is_packed): + tag_size = _TagSize(field_number) + if is_packed: + local_VarintSize = _VarintSize + def PackedFieldSize(value): + result = len(value) * value_size + return result + local_VarintSize(result) + tag_size + return PackedFieldSize + elif is_repeated: + element_size = value_size + tag_size + def RepeatedFieldSize(value): + return len(value) * element_size + return RepeatedFieldSize + else: + field_size = value_size + tag_size + def FieldSize(value): + return field_size + return FieldSize + + return SpecificSizer + + +# ==================================================================== +# Here we declare a sizer constructor for each field type. Each "sizer +# constructor" is a function that takes (field_number, is_repeated, is_packed) +# as parameters and returns a sizer, which in turn takes a field value as +# a parameter and returns its encoded size. + + +Int32Sizer = Int64Sizer = EnumSizer = _SimpleSizer(_SignedVarintSize) + +UInt32Sizer = UInt64Sizer = _SimpleSizer(_VarintSize) + +SInt32Sizer = SInt64Sizer = _ModifiedSizer( + _SignedVarintSize, wire_format.ZigZagEncode) + +Fixed32Sizer = SFixed32Sizer = FloatSizer = _FixedSizer(4) +Fixed64Sizer = SFixed64Sizer = DoubleSizer = _FixedSizer(8) + +BoolSizer = _FixedSizer(1) + + +def StringSizer(field_number, is_repeated, is_packed): + """Returns a sizer for a string field.""" + + tag_size = _TagSize(field_number) + local_VarintSize = _VarintSize + local_len = len + assert not is_packed + if is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + l = local_len(element.encode('utf-8')) + result += local_VarintSize(l) + l + return result + return RepeatedFieldSize + else: + def FieldSize(value): + l = local_len(value.encode('utf-8')) + return tag_size + local_VarintSize(l) + l + return FieldSize + + +def BytesSizer(field_number, is_repeated, is_packed): + """Returns a sizer for a bytes field.""" + + tag_size = _TagSize(field_number) + local_VarintSize = _VarintSize + local_len = len + assert not is_packed + if is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + l = local_len(element) + result += local_VarintSize(l) + l + return result + return RepeatedFieldSize + else: + def FieldSize(value): + l = local_len(value) + return tag_size + local_VarintSize(l) + l + return FieldSize + + +def GroupSizer(field_number, is_repeated, is_packed): + """Returns a sizer for a group field.""" + + tag_size = _TagSize(field_number) * 2 + assert not is_packed + if is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + result += element.ByteSize() + return result + return RepeatedFieldSize + else: + def FieldSize(value): + return tag_size + value.ByteSize() + return FieldSize + + +def MessageSizer(field_number, is_repeated, is_packed): + """Returns a sizer for a message field.""" + + tag_size = _TagSize(field_number) + local_VarintSize = _VarintSize + assert not is_packed + if is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + l = element.ByteSize() + result += local_VarintSize(l) + l + return result + return RepeatedFieldSize + else: + def FieldSize(value): + l = value.ByteSize() + return tag_size + local_VarintSize(l) + l + return FieldSize + + +# -------------------------------------------------------------------- +# MessageSet is special: it needs custom logic to compute its size properly. + + +def MessageSetItemSizer(field_number): + """Returns a sizer for extensions of MessageSet. + + The message set message looks like this: + message MessageSet { + repeated group Item = 1 { + required int32 type_id = 2; + required string message = 3; + } + } + """ + static_size = (_TagSize(1) * 2 + _TagSize(2) + _VarintSize(field_number) + + _TagSize(3)) + local_VarintSize = _VarintSize + + def FieldSize(value): + l = value.ByteSize() + return static_size + local_VarintSize(l) + l + + return FieldSize + + +# -------------------------------------------------------------------- +# Map is special: it needs custom logic to compute its size properly. + + +def MapSizer(field_descriptor, is_message_map): + """Returns a sizer for a map field.""" + + # Can't look at field_descriptor.message_type._concrete_class because it may + # not have been initialized yet. + message_type = field_descriptor.message_type + message_sizer = MessageSizer(field_descriptor.number, False, False) + + def FieldSize(map_value): + total = 0 + for key in map_value: + value = map_value[key] + # It's wasteful to create the messages and throw them away one second + # later since we'll do the same for the actual encode. But there's not an + # obvious way to avoid this within the current design without tons of code + # duplication. For message map, value.ByteSize() should be called to + # update the status. + entry_msg = message_type._concrete_class(key=key, value=value) + total += message_sizer(entry_msg) + if is_message_map: + value.ByteSize() + return total + + return FieldSize + +# ==================================================================== +# Encoders! + + +def _VarintEncoder(): + """Return an encoder for a basic varint value (does not include tag).""" + + local_int2byte = six.int2byte + def EncodeVarint(write, value, unused_deterministic=None): + bits = value & 0x7f + value >>= 7 + while value: + write(local_int2byte(0x80|bits)) + bits = value & 0x7f + value >>= 7 + return write(local_int2byte(bits)) + + return EncodeVarint + + +def _SignedVarintEncoder(): + """Return an encoder for a basic signed varint value (does not include + tag).""" + + local_int2byte = six.int2byte + def EncodeSignedVarint(write, value, unused_deterministic=None): + if value < 0: + value += (1 << 64) + bits = value & 0x7f + value >>= 7 + while value: + write(local_int2byte(0x80|bits)) + bits = value & 0x7f + value >>= 7 + return write(local_int2byte(bits)) + + return EncodeSignedVarint + + +_EncodeVarint = _VarintEncoder() +_EncodeSignedVarint = _SignedVarintEncoder() + + +def _VarintBytes(value): + """Encode the given integer as a varint and return the bytes. This is only + called at startup time so it doesn't need to be fast.""" + + pieces = [] + _EncodeVarint(pieces.append, value, True) + return b"".join(pieces) + + +def TagBytes(field_number, wire_type): + """Encode the given tag and return the bytes. Only called at startup.""" + + return six.binary_type( + _VarintBytes(wire_format.PackTag(field_number, wire_type))) + +# -------------------------------------------------------------------- +# As with sizers (see above), we have a number of common encoder +# implementations. + + +def _SimpleEncoder(wire_type, encode_value, compute_value_size): + """Return a constructor for an encoder for fields of a particular type. + + Args: + wire_type: The field's wire type, for encoding tags. + encode_value: A function which encodes an individual value, e.g. + _EncodeVarint(). + compute_value_size: A function which computes the size of an individual + value, e.g. _VarintSize(). + """ + + def SpecificEncoder(field_number, is_repeated, is_packed): + if is_packed: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + def EncodePackedField(write, value, deterministic): + write(tag_bytes) + size = 0 + for element in value: + size += compute_value_size(element) + local_EncodeVarint(write, size, deterministic) + for element in value: + encode_value(write, element, deterministic) + return EncodePackedField + elif is_repeated: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeRepeatedField(write, value, deterministic): + for element in value: + write(tag_bytes) + encode_value(write, element, deterministic) + return EncodeRepeatedField + else: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeField(write, value, deterministic): + write(tag_bytes) + return encode_value(write, value, deterministic) + return EncodeField + + return SpecificEncoder + + +def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value): + """Like SimpleEncoder but additionally invokes modify_value on every value + before passing it to encode_value. Usually modify_value is ZigZagEncode.""" + + def SpecificEncoder(field_number, is_repeated, is_packed): + if is_packed: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + def EncodePackedField(write, value, deterministic): + write(tag_bytes) + size = 0 + for element in value: + size += compute_value_size(modify_value(element)) + local_EncodeVarint(write, size, deterministic) + for element in value: + encode_value(write, modify_value(element), deterministic) + return EncodePackedField + elif is_repeated: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeRepeatedField(write, value, deterministic): + for element in value: + write(tag_bytes) + encode_value(write, modify_value(element), deterministic) + return EncodeRepeatedField + else: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeField(write, value, deterministic): + write(tag_bytes) + return encode_value(write, modify_value(value), deterministic) + return EncodeField + + return SpecificEncoder + + +def _StructPackEncoder(wire_type, format): + """Return a constructor for an encoder for a fixed-width field. + + Args: + wire_type: The field's wire type, for encoding tags. + format: The format string to pass to struct.pack(). + """ + + value_size = struct.calcsize(format) + + def SpecificEncoder(field_number, is_repeated, is_packed): + local_struct_pack = struct.pack + if is_packed: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + def EncodePackedField(write, value, deterministic): + write(tag_bytes) + local_EncodeVarint(write, len(value) * value_size, deterministic) + for element in value: + write(local_struct_pack(format, element)) + return EncodePackedField + elif is_repeated: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeRepeatedField(write, value, unused_deterministic=None): + for element in value: + write(tag_bytes) + write(local_struct_pack(format, element)) + return EncodeRepeatedField + else: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeField(write, value, unused_deterministic=None): + write(tag_bytes) + return write(local_struct_pack(format, value)) + return EncodeField + + return SpecificEncoder + + +def _FloatingPointEncoder(wire_type, format): + """Return a constructor for an encoder for float fields. + + This is like StructPackEncoder, but catches errors that may be due to + passing non-finite floating-point values to struct.pack, and makes a + second attempt to encode those values. + + Args: + wire_type: The field's wire type, for encoding tags. + format: The format string to pass to struct.pack(). + """ + + value_size = struct.calcsize(format) + if value_size == 4: + def EncodeNonFiniteOrRaise(write, value): + # Remember that the serialized form uses little-endian byte order. + if value == _POS_INF: + write(b'\x00\x00\x80\x7F') + elif value == _NEG_INF: + write(b'\x00\x00\x80\xFF') + elif value != value: # NaN + write(b'\x00\x00\xC0\x7F') + else: + raise + elif value_size == 8: + def EncodeNonFiniteOrRaise(write, value): + if value == _POS_INF: + write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F') + elif value == _NEG_INF: + write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF') + elif value != value: # NaN + write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F') + else: + raise + else: + raise ValueError('Can\'t encode floating-point values that are ' + '%d bytes long (only 4 or 8)' % value_size) + + def SpecificEncoder(field_number, is_repeated, is_packed): + local_struct_pack = struct.pack + if is_packed: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + def EncodePackedField(write, value, deterministic): + write(tag_bytes) + local_EncodeVarint(write, len(value) * value_size, deterministic) + for element in value: + # This try/except block is going to be faster than any code that + # we could write to check whether element is finite. + try: + write(local_struct_pack(format, element)) + except SystemError: + EncodeNonFiniteOrRaise(write, element) + return EncodePackedField + elif is_repeated: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeRepeatedField(write, value, unused_deterministic=None): + for element in value: + write(tag_bytes) + try: + write(local_struct_pack(format, element)) + except SystemError: + EncodeNonFiniteOrRaise(write, element) + return EncodeRepeatedField + else: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeField(write, value, unused_deterministic=None): + write(tag_bytes) + try: + write(local_struct_pack(format, value)) + except SystemError: + EncodeNonFiniteOrRaise(write, value) + return EncodeField + + return SpecificEncoder + + +# ==================================================================== +# Here we declare an encoder constructor for each field type. These work +# very similarly to sizer constructors, described earlier. + + +Int32Encoder = Int64Encoder = EnumEncoder = _SimpleEncoder( + wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize) + +UInt32Encoder = UInt64Encoder = _SimpleEncoder( + wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize) + +SInt32Encoder = SInt64Encoder = _ModifiedEncoder( + wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize, + wire_format.ZigZagEncode) + +# Note that Python conveniently guarantees that when using the '<' prefix on +# formats, they will also have the same size across all platforms (as opposed +# to without the prefix, where their sizes depend on the C compiler's basic +# type sizes). +Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I') +Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q') +SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i') +SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q') +FloatEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED32, '<f') +DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d') + + +def BoolEncoder(field_number, is_repeated, is_packed): + """Returns an encoder for a boolean field.""" + + false_byte = b'\x00' + true_byte = b'\x01' + if is_packed: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + def EncodePackedField(write, value, deterministic): + write(tag_bytes) + local_EncodeVarint(write, len(value), deterministic) + for element in value: + if element: + write(true_byte) + else: + write(false_byte) + return EncodePackedField + elif is_repeated: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT) + def EncodeRepeatedField(write, value, unused_deterministic=None): + for element in value: + write(tag_bytes) + if element: + write(true_byte) + else: + write(false_byte) + return EncodeRepeatedField + else: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT) + def EncodeField(write, value, unused_deterministic=None): + write(tag_bytes) + if value: + return write(true_byte) + return write(false_byte) + return EncodeField + + +def StringEncoder(field_number, is_repeated, is_packed): + """Returns an encoder for a string field.""" + + tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + local_len = len + assert not is_packed + if is_repeated: + def EncodeRepeatedField(write, value, deterministic): + for element in value: + encoded = element.encode('utf-8') + write(tag) + local_EncodeVarint(write, local_len(encoded), deterministic) + write(encoded) + return EncodeRepeatedField + else: + def EncodeField(write, value, deterministic): + encoded = value.encode('utf-8') + write(tag) + local_EncodeVarint(write, local_len(encoded), deterministic) + return write(encoded) + return EncodeField + + +def BytesEncoder(field_number, is_repeated, is_packed): + """Returns an encoder for a bytes field.""" + + tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + local_len = len + assert not is_packed + if is_repeated: + def EncodeRepeatedField(write, value, deterministic): + for element in value: + write(tag) + local_EncodeVarint(write, local_len(element), deterministic) + write(element) + return EncodeRepeatedField + else: + def EncodeField(write, value, deterministic): + write(tag) + local_EncodeVarint(write, local_len(value), deterministic) + return write(value) + return EncodeField + + +def GroupEncoder(field_number, is_repeated, is_packed): + """Returns an encoder for a group field.""" + + start_tag = TagBytes(field_number, wire_format.WIRETYPE_START_GROUP) + end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP) + assert not is_packed + if is_repeated: + def EncodeRepeatedField(write, value, deterministic): + for element in value: + write(start_tag) + element._InternalSerialize(write, deterministic) + write(end_tag) + return EncodeRepeatedField + else: + def EncodeField(write, value, deterministic): + write(start_tag) + value._InternalSerialize(write, deterministic) + return write(end_tag) + return EncodeField + + +def MessageEncoder(field_number, is_repeated, is_packed): + """Returns an encoder for a message field.""" + + tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + assert not is_packed + if is_repeated: + def EncodeRepeatedField(write, value, deterministic): + for element in value: + write(tag) + local_EncodeVarint(write, element.ByteSize(), deterministic) + element._InternalSerialize(write, deterministic) + return EncodeRepeatedField + else: + def EncodeField(write, value, deterministic): + write(tag) + local_EncodeVarint(write, value.ByteSize(), deterministic) + return value._InternalSerialize(write, deterministic) + return EncodeField + + +# -------------------------------------------------------------------- +# As before, MessageSet is special. + + +def MessageSetItemEncoder(field_number): + """Encoder for extensions of MessageSet. + + The message set message looks like this: + message MessageSet { + repeated group Item = 1 { + required int32 type_id = 2; + required string message = 3; + } + } + """ + start_bytes = b"".join([ + TagBytes(1, wire_format.WIRETYPE_START_GROUP), + TagBytes(2, wire_format.WIRETYPE_VARINT), + _VarintBytes(field_number), + TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)]) + end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP) + local_EncodeVarint = _EncodeVarint + + def EncodeField(write, value, deterministic): + write(start_bytes) + local_EncodeVarint(write, value.ByteSize(), deterministic) + value._InternalSerialize(write, deterministic) + return write(end_bytes) + + return EncodeField + + +# -------------------------------------------------------------------- +# As before, Map is special. + + +def MapEncoder(field_descriptor): + """Encoder for extensions of MessageSet. + + Maps always have a wire format like this: + message MapEntry { + key_type key = 1; + value_type value = 2; + } + repeated MapEntry map = N; + """ + # Can't look at field_descriptor.message_type._concrete_class because it may + # not have been initialized yet. + message_type = field_descriptor.message_type + encode_message = MessageEncoder(field_descriptor.number, False, False) + + def EncodeField(write, value, deterministic): + value_keys = sorted(value.keys()) if deterministic else value + for key in value_keys: + entry_msg = message_type._concrete_class(key=key, value=value[key]) + encode_message(write, entry_msg, deterministic) + + return EncodeField diff --git a/contrib/python/protobuf/py3/google/protobuf/internal/enum_type_wrapper.py b/contrib/python/protobuf/py3/google/protobuf/internal/enum_type_wrapper.py new file mode 100644 index 0000000000..9ae0066584 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/internal/enum_type_wrapper.py @@ -0,0 +1,117 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""A simple wrapper around enum types to expose utility functions. + +Instances are created as properties with the same name as the enum they wrap +on proto classes. For usage, see: + reflection_test.py +""" + +__author__ = 'rabsatt@google.com (Kevin Rabsatt)' + +import six + + +class EnumTypeWrapper(object): + """A utility for finding the names of enum values.""" + + DESCRIPTOR = None + + def __init__(self, enum_type): + """Inits EnumTypeWrapper with an EnumDescriptor.""" + self._enum_type = enum_type + self.DESCRIPTOR = enum_type # pylint: disable=invalid-name + + def Name(self, number): # pylint: disable=invalid-name + """Returns a string containing the name of an enum value.""" + try: + return self._enum_type.values_by_number[number].name + except KeyError: + pass # fall out to break exception chaining + + if not isinstance(number, six.integer_types): + raise TypeError( + 'Enum value for {} must be an int, but got {} {!r}.'.format( + self._enum_type.name, type(number), number)) + else: + # repr here to handle the odd case when you pass in a boolean. + raise ValueError('Enum {} has no name defined for value {!r}'.format( + self._enum_type.name, number)) + + def Value(self, name): # pylint: disable=invalid-name + """Returns the value corresponding to the given enum name.""" + try: + return self._enum_type.values_by_name[name].number + except KeyError: + pass # fall out to break exception chaining + raise ValueError('Enum {} has no value defined for name {!r}'.format( + self._enum_type.name, name)) + + def keys(self): + """Return a list of the string names in the enum. + + Returns: + A list of strs, in the order they were defined in the .proto file. + """ + + return [value_descriptor.name + for value_descriptor in self._enum_type.values] + + def values(self): + """Return a list of the integer values in the enum. + + Returns: + A list of ints, in the order they were defined in the .proto file. + """ + + return [value_descriptor.number + for value_descriptor in self._enum_type.values] + + def items(self): + """Return a list of the (name, value) pairs of the enum. + + Returns: + A list of (str, int) pairs, in the order they were defined + in the .proto file. + """ + return [(value_descriptor.name, value_descriptor.number) + for value_descriptor in self._enum_type.values] + + def __getattr__(self, name): + """Returns the value corresponding to the given enum name.""" + try: + return super( + EnumTypeWrapper, + self).__getattribute__('_enum_type').values_by_name[name].number + except KeyError: + pass # fall out to break exception chaining + raise AttributeError('Enum {} has no value defined for name {!r}'.format( + self._enum_type.name, name)) diff --git a/contrib/python/protobuf/py3/google/protobuf/internal/extension_dict.py b/contrib/python/protobuf/py3/google/protobuf/internal/extension_dict.py new file mode 100644 index 0000000000..b346cf283e --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/internal/extension_dict.py @@ -0,0 +1,213 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains _ExtensionDict class to represent extensions. +""" + +from google.protobuf.internal import type_checkers +from google.protobuf.descriptor import FieldDescriptor + + +def _VerifyExtensionHandle(message, extension_handle): + """Verify that the given extension handle is valid.""" + + if not isinstance(extension_handle, FieldDescriptor): + raise KeyError('HasExtension() expects an extension handle, got: %s' % + extension_handle) + + if not extension_handle.is_extension: + raise KeyError('"%s" is not an extension.' % extension_handle.full_name) + + if not extension_handle.containing_type: + raise KeyError('"%s" is missing a containing_type.' + % extension_handle.full_name) + + if extension_handle.containing_type is not message.DESCRIPTOR: + raise KeyError('Extension "%s" extends message type "%s", but this ' + 'message is of type "%s".' % + (extension_handle.full_name, + extension_handle.containing_type.full_name, + message.DESCRIPTOR.full_name)) + + +# TODO(robinson): Unify error handling of "unknown extension" crap. +# TODO(robinson): Support iteritems()-style iteration over all +# extensions with the "has" bits turned on? +class _ExtensionDict(object): + + """Dict-like container for Extension fields on proto instances. + + Note that in all cases we expect extension handles to be + FieldDescriptors. + """ + + def __init__(self, extended_message): + """ + Args: + extended_message: Message instance for which we are the Extensions dict. + """ + self._extended_message = extended_message + + def __getitem__(self, extension_handle): + """Returns the current value of the given extension handle.""" + + _VerifyExtensionHandle(self._extended_message, extension_handle) + + result = self._extended_message._fields.get(extension_handle) + if result is not None: + return result + + if extension_handle.label == FieldDescriptor.LABEL_REPEATED: + result = extension_handle._default_constructor(self._extended_message) + elif extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + message_type = extension_handle.message_type + if not hasattr(message_type, '_concrete_class'): + # pylint: disable=protected-access + self._extended_message._FACTORY.GetPrototype(message_type) + assert getattr(extension_handle.message_type, '_concrete_class', None), ( + 'Uninitialized concrete class found for field %r (message type %r)' + % (extension_handle.full_name, + extension_handle.message_type.full_name)) + result = extension_handle.message_type._concrete_class() + try: + result._SetListener(self._extended_message._listener_for_children) + except ReferenceError: + pass + else: + # Singular scalar -- just return the default without inserting into the + # dict. + return extension_handle.default_value + + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + result = self._extended_message._fields.setdefault( + extension_handle, result) + + return result + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + + my_fields = self._extended_message.ListFields() + other_fields = other._extended_message.ListFields() + + # Get rid of non-extension fields. + my_fields = [field for field in my_fields if field.is_extension] + other_fields = [field for field in other_fields if field.is_extension] + + return my_fields == other_fields + + def __ne__(self, other): + return not self == other + + def __len__(self): + fields = self._extended_message.ListFields() + # Get rid of non-extension fields. + extension_fields = [field for field in fields if field[0].is_extension] + return len(extension_fields) + + def __hash__(self): + raise TypeError('unhashable object') + + # Note that this is only meaningful for non-repeated, scalar extension + # fields. Note also that we may have to call _Modified() when we do + # successfully set a field this way, to set any necessary "has" bits in the + # ancestors of the extended message. + def __setitem__(self, extension_handle, value): + """If extension_handle specifies a non-repeated, scalar extension + field, sets the value of that field. + """ + + _VerifyExtensionHandle(self._extended_message, extension_handle) + + if (extension_handle.label == FieldDescriptor.LABEL_REPEATED or + extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE): + raise TypeError( + 'Cannot assign to extension "%s" because it is a repeated or ' + 'composite type.' % extension_handle.full_name) + + # It's slightly wasteful to lookup the type checker each time, + # but we expect this to be a vanishingly uncommon case anyway. + type_checker = type_checkers.GetTypeChecker(extension_handle) + # pylint: disable=protected-access + self._extended_message._fields[extension_handle] = ( + type_checker.CheckValue(value)) + self._extended_message._Modified() + + def __delitem__(self, extension_handle): + self._extended_message.ClearExtension(extension_handle) + + def _FindExtensionByName(self, name): + """Tries to find a known extension with the specified name. + + Args: + name: Extension full name. + + Returns: + Extension field descriptor. + """ + return self._extended_message._extensions_by_name.get(name, None) + + def _FindExtensionByNumber(self, number): + """Tries to find a known extension with the field number. + + Args: + number: Extension field number. + + Returns: + Extension field descriptor. + """ + return self._extended_message._extensions_by_number.get(number, None) + + def __iter__(self): + # Return a generator over the populated extension fields + return (f[0] for f in self._extended_message.ListFields() + if f[0].is_extension) + + def __contains__(self, extension_handle): + _VerifyExtensionHandle(self._extended_message, extension_handle) + + if extension_handle not in self._extended_message._fields: + return False + + if extension_handle.label == FieldDescriptor.LABEL_REPEATED: + return bool(self._extended_message._fields.get(extension_handle)) + + if extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + value = self._extended_message._fields.get(extension_handle) + # pylint: disable=protected-access + return value is not None and value._is_present_in_parent + + return True diff --git a/contrib/python/protobuf/py3/google/protobuf/internal/message_listener.py b/contrib/python/protobuf/py3/google/protobuf/internal/message_listener.py new file mode 100644 index 0000000000..0fc255a774 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/internal/message_listener.py @@ -0,0 +1,78 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Defines a listener interface for observing certain +state transitions on Message objects. + +Also defines a null implementation of this interface. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + + +class MessageListener(object): + + """Listens for modifications made to a message. Meant to be registered via + Message._SetListener(). + + Attributes: + dirty: If True, then calling Modified() would be a no-op. This can be + used to avoid these calls entirely in the common case. + """ + + def Modified(self): + """Called every time the message is modified in such a way that the parent + message may need to be updated. This currently means either: + (a) The message was modified for the first time, so the parent message + should henceforth mark the message as present. + (b) The message's cached byte size became dirty -- i.e. the message was + modified for the first time after a previous call to ByteSize(). + Therefore the parent should also mark its byte size as dirty. + Note that (a) implies (b), since new objects start out with a client cached + size (zero). However, we document (a) explicitly because it is important. + + Modified() will *only* be called in response to one of these two events -- + not every time the sub-message is modified. + + Note that if the listener's |dirty| attribute is true, then calling + Modified at the moment would be a no-op, so it can be skipped. Performance- + sensitive callers should check this attribute directly before calling since + it will be true most of the time. + """ + + raise NotImplementedError + + +class NullMessageListener(object): + + """No-op MessageListener implementation.""" + + def Modified(self): + pass diff --git a/contrib/python/protobuf/py3/google/protobuf/internal/python_message.py b/contrib/python/protobuf/py3/google/protobuf/internal/python_message.py new file mode 100644 index 0000000000..99d2f078de --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/internal/python_message.py @@ -0,0 +1,1541 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# This code is meant to work on Python 2.4 and above only. +# +# TODO(robinson): Helpers for verbose, common checks like seeing if a +# descriptor's cpp_type is CPPTYPE_MESSAGE. + +"""Contains a metaclass and helper functions used to create +protocol message classes from Descriptor objects at runtime. + +Recall that a metaclass is the "type" of a class. +(A class is to a metaclass what an instance is to a class.) + +In this case, we use the GeneratedProtocolMessageType metaclass +to inject all the useful functionality into the classes +output by the protocol compiler at compile-time. + +The upshot of all this is that the real implementation +details for ALL pure-Python protocol buffers are *here in +this file*. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +from io import BytesIO +import struct +import sys +import weakref + +import six +from six.moves import range + +# We use "as" to avoid name collisions with variables. +from google.protobuf.internal import api_implementation +from google.protobuf.internal import containers +from google.protobuf.internal import decoder +from google.protobuf.internal import encoder +from google.protobuf.internal import enum_type_wrapper +from google.protobuf.internal import extension_dict +from google.protobuf.internal import message_listener as message_listener_mod +from google.protobuf.internal import type_checkers +from google.protobuf.internal import well_known_types +from google.protobuf.internal import wire_format +from google.protobuf import descriptor as descriptor_mod +from google.protobuf import message as message_mod +from google.protobuf import text_format + +_FieldDescriptor = descriptor_mod.FieldDescriptor +_AnyFullTypeName = 'google.protobuf.Any' +_ExtensionDict = extension_dict._ExtensionDict + +class GeneratedProtocolMessageType(type): + + """Metaclass for protocol message classes created at runtime from Descriptors. + + We add implementations for all methods described in the Message class. We + also create properties to allow getting/setting all fields in the protocol + message. Finally, we create slots to prevent users from accidentally + "setting" nonexistent fields in the protocol message, which then wouldn't get + serialized / deserialized properly. + + The protocol compiler currently uses this metaclass to create protocol + message classes at runtime. Clients can also manually create their own + classes at runtime, as in this example: + + mydescriptor = Descriptor(.....) + factory = symbol_database.Default() + factory.pool.AddDescriptor(mydescriptor) + MyProtoClass = factory.GetPrototype(mydescriptor) + myproto_instance = MyProtoClass() + myproto.foo_field = 23 + ... + """ + + # Must be consistent with the protocol-compiler code in + # proto2/compiler/internal/generator.*. + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __new__(cls, name, bases, dictionary): + """Custom allocation for runtime-generated class types. + + We override __new__ because this is apparently the only place + where we can meaningfully set __slots__ on the class we're creating(?). + (The interplay between metaclasses and slots is not very well-documented). + + Args: + name: Name of the class (ignored, but required by the + metaclass protocol). + bases: Base classes of the class we're constructing. + (Should be message.Message). We ignore this field, but + it's required by the metaclass protocol + dictionary: The class dictionary of the class we're + constructing. dictionary[_DESCRIPTOR_KEY] must contain + a Descriptor object describing this protocol message + type. + + Returns: + Newly-allocated class. + + Raises: + RuntimeError: Generated code only work with python cpp extension. + """ + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + + if isinstance(descriptor, str): + raise RuntimeError('The generated code only work with python cpp ' + 'extension, but it is using pure python runtime.') + + # If a concrete class already exists for this descriptor, don't try to + # create another. Doing so will break any messages that already exist with + # the existing class. + # + # The C++ implementation appears to have its own internal `PyMessageFactory` + # to achieve similar results. + # + # This most commonly happens in `text_format.py` when using descriptors from + # a custom pool; it calls symbol_database.Global().getPrototype() on a + # descriptor which already has an existing concrete class. + new_class = getattr(descriptor, '_concrete_class', None) + if new_class: + return new_class + + if descriptor.full_name in well_known_types.WKTBASES: + bases += (well_known_types.WKTBASES[descriptor.full_name],) + _AddClassAttributesForNestedExtensions(descriptor, dictionary) + _AddSlots(descriptor, dictionary) + + superclass = super(GeneratedProtocolMessageType, cls) + new_class = superclass.__new__(cls, name, bases, dictionary) + return new_class + + def __init__(cls, name, bases, dictionary): + """Here we perform the majority of our work on the class. + We add enum getters, an __init__ method, implementations + of all Message methods, and properties for all fields + in the protocol type. + + Args: + name: Name of the class (ignored, but required by the + metaclass protocol). + bases: Base classes of the class we're constructing. + (Should be message.Message). We ignore this field, but + it's required by the metaclass protocol + dictionary: The class dictionary of the class we're + constructing. dictionary[_DESCRIPTOR_KEY] must contain + a Descriptor object describing this protocol message + type. + """ + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + + # If this is an _existing_ class looked up via `_concrete_class` in the + # __new__ method above, then we don't need to re-initialize anything. + existing_class = getattr(descriptor, '_concrete_class', None) + if existing_class: + assert existing_class is cls, ( + 'Duplicate `GeneratedProtocolMessageType` created for descriptor %r' + % (descriptor.full_name)) + return + + cls._decoders_by_tag = {} + if (descriptor.has_options and + descriptor.GetOptions().message_set_wire_format): + cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( + decoder.MessageSetItemDecoder(descriptor), None) + + # Attach stuff to each FieldDescriptor for quick lookup later on. + for field in descriptor.fields: + _AttachFieldHelpers(cls, field) + + descriptor._concrete_class = cls # pylint: disable=protected-access + _AddEnumValues(descriptor, cls) + _AddInitMethod(descriptor, cls) + _AddPropertiesForFields(descriptor, cls) + _AddPropertiesForExtensions(descriptor, cls) + _AddStaticMethods(cls) + _AddMessageMethods(descriptor, cls) + _AddPrivateHelperMethods(descriptor, cls) + + superclass = super(GeneratedProtocolMessageType, cls) + superclass.__init__(name, bases, dictionary) + + +# Stateless helpers for GeneratedProtocolMessageType below. +# Outside clients should not access these directly. +# +# I opted not to make any of these methods on the metaclass, to make it more +# clear that I'm not really using any state there and to keep clients from +# thinking that they have direct access to these construction helpers. + + +def _PropertyName(proto_field_name): + """Returns the name of the public property attribute which + clients can use to get and (in some cases) set the value + of a protocol message field. + + Args: + proto_field_name: The protocol message field name, exactly + as it appears (or would appear) in a .proto file. + """ + # TODO(robinson): Escape Python keywords (e.g., yield), and test this support. + # nnorwitz makes my day by writing: + # """ + # FYI. See the keyword module in the stdlib. This could be as simple as: + # + # if keyword.iskeyword(proto_field_name): + # return proto_field_name + "_" + # return proto_field_name + # """ + # Kenton says: The above is a BAD IDEA. People rely on being able to use + # getattr() and setattr() to reflectively manipulate field values. If we + # rename the properties, then every such user has to also make sure to apply + # the same transformation. Note that currently if you name a field "yield", + # you can still access it just fine using getattr/setattr -- it's not even + # that cumbersome to do so. + # TODO(kenton): Remove this method entirely if/when everyone agrees with my + # position. + return proto_field_name + + +def _AddSlots(message_descriptor, dictionary): + """Adds a __slots__ entry to dictionary, containing the names of all valid + attributes for this message type. + + Args: + message_descriptor: A Descriptor instance describing this message type. + dictionary: Class dictionary to which we'll add a '__slots__' entry. + """ + dictionary['__slots__'] = ['_cached_byte_size', + '_cached_byte_size_dirty', + '_fields', + '_unknown_fields', + '_unknown_field_set', + '_is_present_in_parent', + '_listener', + '_listener_for_children', + '__weakref__', + '_oneofs'] + + +def _IsMessageSetExtension(field): + return (field.is_extension and + field.containing_type.has_options and + field.containing_type.GetOptions().message_set_wire_format and + field.type == _FieldDescriptor.TYPE_MESSAGE and + field.label == _FieldDescriptor.LABEL_OPTIONAL) + + +def _IsMapField(field): + return (field.type == _FieldDescriptor.TYPE_MESSAGE and + field.message_type.has_options and + field.message_type.GetOptions().map_entry) + + +def _IsMessageMapField(field): + value_type = field.message_type.fields_by_name['value'] + return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE + + +def _IsStrictUtf8Check(field): + if field.containing_type.syntax != 'proto3': + return False + enforce_utf8 = True + return enforce_utf8 + + +def _AttachFieldHelpers(cls, field_descriptor): + is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) + is_packable = (is_repeated and + wire_format.IsTypePackable(field_descriptor.type)) + is_proto3 = field_descriptor.containing_type.syntax == 'proto3' + if not is_packable: + is_packed = False + elif field_descriptor.containing_type.syntax == 'proto2': + is_packed = (field_descriptor.has_options and + field_descriptor.GetOptions().packed) + else: + has_packed_false = (field_descriptor.has_options and + field_descriptor.GetOptions().HasField('packed') and + field_descriptor.GetOptions().packed == False) + is_packed = not has_packed_false + is_map_entry = _IsMapField(field_descriptor) + + if is_map_entry: + field_encoder = encoder.MapEncoder(field_descriptor) + sizer = encoder.MapSizer(field_descriptor, + _IsMessageMapField(field_descriptor)) + elif _IsMessageSetExtension(field_descriptor): + field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) + sizer = encoder.MessageSetItemSizer(field_descriptor.number) + else: + field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed) + sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed) + + field_descriptor._encoder = field_encoder + field_descriptor._sizer = sizer + field_descriptor._default_constructor = _DefaultValueConstructorForField( + field_descriptor) + + def AddDecoder(wiretype, is_packed): + tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) + decode_type = field_descriptor.type + if (decode_type == _FieldDescriptor.TYPE_ENUM and + type_checkers.SupportsOpenEnums(field_descriptor)): + decode_type = _FieldDescriptor.TYPE_INT32 + + oneof_descriptor = None + clear_if_default = False + if field_descriptor.containing_oneof is not None: + oneof_descriptor = field_descriptor + elif (is_proto3 and not is_repeated and + field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE): + clear_if_default = True + + if is_map_entry: + is_message_map = _IsMessageMapField(field_descriptor) + + field_decoder = decoder.MapDecoder( + field_descriptor, _GetInitializeDefaultForMap(field_descriptor), + is_message_map) + elif decode_type == _FieldDescriptor.TYPE_STRING: + is_strict_utf8_check = _IsStrictUtf8Check(field_descriptor) + field_decoder = decoder.StringDecoder( + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor, + is_strict_utf8_check, clear_if_default) + elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor) + else: + field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( + field_descriptor.number, is_repeated, is_packed, + # pylint: disable=protected-access + field_descriptor, field_descriptor._default_constructor, + clear_if_default) + + cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) + + AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], + False) + + if is_repeated and wire_format.IsTypePackable(field_descriptor.type): + # To support wire compatibility of adding packed = true, add a decoder for + # packed values regardless of the field's options. + AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) + + +def _AddClassAttributesForNestedExtensions(descriptor, dictionary): + extensions = descriptor.extensions_by_name + for extension_name, extension_field in extensions.items(): + assert extension_name not in dictionary + dictionary[extension_name] = extension_field + + +def _AddEnumValues(descriptor, cls): + """Sets class-level attributes for all enum fields defined in this message. + + Also exporting a class-level object that can name enum values. + + Args: + descriptor: Descriptor object for this message type. + cls: Class we're constructing for this message type. + """ + for enum_type in descriptor.enum_types: + setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) + for enum_value in enum_type.values: + setattr(cls, enum_value.name, enum_value.number) + + +def _GetInitializeDefaultForMap(field): + if field.label != _FieldDescriptor.LABEL_REPEATED: + raise ValueError('map_entry set on non-repeated field %s' % ( + field.name)) + fields_by_name = field.message_type.fields_by_name + key_checker = type_checkers.GetTypeChecker(fields_by_name['key']) + + value_field = fields_by_name['value'] + if _IsMessageMapField(field): + def MakeMessageMapDefault(message): + return containers.MessageMap( + message._listener_for_children, value_field.message_type, key_checker, + field.message_type) + return MakeMessageMapDefault + else: + value_checker = type_checkers.GetTypeChecker(value_field) + def MakePrimitiveMapDefault(message): + return containers.ScalarMap( + message._listener_for_children, key_checker, value_checker, + field.message_type) + return MakePrimitiveMapDefault + +def _DefaultValueConstructorForField(field): + """Returns a function which returns a default value for a field. + + Args: + field: FieldDescriptor object for this field. + + The returned function has one argument: + message: Message instance containing this field, or a weakref proxy + of same. + + That function in turn returns a default value for this field. The default + value may refer back to |message| via a weak reference. + """ + + if _IsMapField(field): + return _GetInitializeDefaultForMap(field) + + if field.label == _FieldDescriptor.LABEL_REPEATED: + if field.has_default_value and field.default_value != []: + raise ValueError('Repeated field default value not empty list: %s' % ( + field.default_value)) + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + # We can't look at _concrete_class yet since it might not have + # been set. (Depends on order in which we initialize the classes). + message_type = field.message_type + def MakeRepeatedMessageDefault(message): + return containers.RepeatedCompositeFieldContainer( + message._listener_for_children, field.message_type) + return MakeRepeatedMessageDefault + else: + type_checker = type_checkers.GetTypeChecker(field) + def MakeRepeatedScalarDefault(message): + return containers.RepeatedScalarFieldContainer( + message._listener_for_children, type_checker) + return MakeRepeatedScalarDefault + + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + # _concrete_class may not yet be initialized. + message_type = field.message_type + def MakeSubMessageDefault(message): + assert getattr(message_type, '_concrete_class', None), ( + 'Uninitialized concrete class found for field %r (message type %r)' + % (field.full_name, message_type.full_name)) + result = message_type._concrete_class() + result._SetListener( + _OneofListener(message, field) + if field.containing_oneof is not None + else message._listener_for_children) + return result + return MakeSubMessageDefault + + def MakeScalarDefault(message): + # TODO(protobuf-team): This may be broken since there may not be + # default_value. Combine with has_default_value somehow. + return field.default_value + return MakeScalarDefault + + +def _ReraiseTypeErrorWithFieldName(message_name, field_name): + """Re-raise the currently-handled TypeError with the field name added.""" + exc = sys.exc_info()[1] + if len(exc.args) == 1 and type(exc) is TypeError: + # simple TypeError; add field name to exception message + exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name)) + + # re-raise possibly-amended exception with original traceback: + six.reraise(type(exc), exc, sys.exc_info()[2]) + + +def _AddInitMethod(message_descriptor, cls): + """Adds an __init__ method to cls.""" + + def _GetIntegerEnumValue(enum_type, value): + """Convert a string or integer enum value to an integer. + + If the value is a string, it is converted to the enum value in + enum_type with the same name. If the value is not a string, it's + returned as-is. (No conversion or bounds-checking is done.) + """ + if isinstance(value, six.string_types): + try: + return enum_type.values_by_name[value].number + except KeyError: + raise ValueError('Enum type %s: unknown label "%s"' % ( + enum_type.full_name, value)) + return value + + def init(self, **kwargs): + self._cached_byte_size = 0 + self._cached_byte_size_dirty = len(kwargs) > 0 + self._fields = {} + # Contains a mapping from oneof field descriptors to the descriptor + # of the currently set field in that oneof field. + self._oneofs = {} + + # _unknown_fields is () when empty for efficiency, and will be turned into + # a list if fields are added. + self._unknown_fields = () + # _unknown_field_set is None when empty for efficiency, and will be + # turned into UnknownFieldSet struct if fields are added. + self._unknown_field_set = None # pylint: disable=protected-access + self._is_present_in_parent = False + self._listener = message_listener_mod.NullMessageListener() + self._listener_for_children = _Listener(self) + for field_name, field_value in kwargs.items(): + field = _GetFieldByName(message_descriptor, field_name) + if field is None: + raise TypeError('%s() got an unexpected keyword argument "%s"' % + (message_descriptor.name, field_name)) + if field_value is None: + # field=None is the same as no field at all. + continue + if field.label == _FieldDescriptor.LABEL_REPEATED: + copy = field._default_constructor(self) + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite + if _IsMapField(field): + if _IsMessageMapField(field): + for key in field_value: + copy[key].MergeFrom(field_value[key]) + else: + copy.update(field_value) + else: + for val in field_value: + if isinstance(val, dict): + copy.add(**val) + else: + copy.add().MergeFrom(val) + else: # Scalar + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: + field_value = [_GetIntegerEnumValue(field.enum_type, val) + for val in field_value] + copy.extend(field_value) + self._fields[field] = copy + elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + copy = field._default_constructor(self) + new_val = field_value + if isinstance(field_value, dict): + new_val = field.message_type._concrete_class(**field_value) + try: + copy.MergeFrom(new_val) + except TypeError: + _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) + self._fields[field] = copy + else: + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: + field_value = _GetIntegerEnumValue(field.enum_type, field_value) + try: + setattr(self, field_name, field_value) + except TypeError: + _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) + + init.__module__ = None + init.__doc__ = None + cls.__init__ = init + + +def _GetFieldByName(message_descriptor, field_name): + """Returns a field descriptor by field name. + + Args: + message_descriptor: A Descriptor describing all fields in message. + field_name: The name of the field to retrieve. + Returns: + The field descriptor associated with the field name. + """ + try: + return message_descriptor.fields_by_name[field_name] + except KeyError: + raise ValueError('Protocol message %s has no "%s" field.' % + (message_descriptor.name, field_name)) + + +def _AddPropertiesForFields(descriptor, cls): + """Adds properties for all fields in this protocol message type.""" + for field in descriptor.fields: + _AddPropertiesForField(field, cls) + + if descriptor.is_extendable: + # _ExtensionDict is just an adaptor with no state so we allocate a new one + # every time it is accessed. + cls.Extensions = property(lambda self: _ExtensionDict(self)) + + +def _AddPropertiesForField(field, cls): + """Adds a public property for a protocol message field. + Clients can use this property to get and (in the case + of non-repeated scalar fields) directly set the value + of a protocol message field. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + # Catch it if we add other types that we should + # handle specially here. + assert _FieldDescriptor.MAX_CPPTYPE == 10 + + constant_name = field.name.upper() + '_FIELD_NUMBER' + setattr(cls, constant_name, field.number) + + if field.label == _FieldDescriptor.LABEL_REPEATED: + _AddPropertiesForRepeatedField(field, cls) + elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + _AddPropertiesForNonRepeatedCompositeField(field, cls) + else: + _AddPropertiesForNonRepeatedScalarField(field, cls) + + +class _FieldProperty(property): + __slots__ = ('DESCRIPTOR',) + + def __init__(self, descriptor, getter, setter, doc): + property.__init__(self, getter, setter, doc=doc) + self.DESCRIPTOR = descriptor + + +def _AddPropertiesForRepeatedField(field, cls): + """Adds a public property for a "repeated" protocol message field. Clients + can use this property to get the value of the field, which will be either a + RepeatedScalarFieldContainer or RepeatedCompositeFieldContainer (see + below). + + Note that when clients add values to these containers, we perform + type-checking in the case of repeated scalar fields, and we also set any + necessary "has" bits as a side-effect. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + proto_field_name = field.name + property_name = _PropertyName(proto_field_name) + + def getter(self): + field_value = self._fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = field._default_constructor(self) + + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + field_value = self._fields.setdefault(field, field_value) + return field_value + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + + # We define a setter just so we can throw an exception with a more + # helpful error message. + def setter(self, new_value): + raise AttributeError('Assignment not allowed to repeated field ' + '"%s" in protocol message object.' % proto_field_name) + + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc)) + + +def _AddPropertiesForNonRepeatedScalarField(field, cls): + """Adds a public property for a nonrepeated, scalar protocol message field. + Clients can use this property to get and directly set the value of the field. + Note that when the client sets the value of a field by using this property, + all necessary "has" bits are set as a side-effect, and we also perform + type-checking. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + proto_field_name = field.name + property_name = _PropertyName(proto_field_name) + type_checker = type_checkers.GetTypeChecker(field) + default_value = field.default_value + is_proto3 = field.containing_type.syntax == 'proto3' + + def getter(self): + # TODO(protobuf-team): This may be broken since there may not be + # default_value. Combine with has_default_value somehow. + return self._fields.get(field, default_value) + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + + clear_when_set_to_default = is_proto3 and not field.containing_oneof + + def field_setter(self, new_value): + # pylint: disable=protected-access + # Testing the value for truthiness captures all of the proto3 defaults + # (0, 0.0, enum 0, and False). + try: + new_value = type_checker.CheckValue(new_value) + except TypeError as e: + raise TypeError( + 'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e)) + if clear_when_set_to_default and not new_value: + self._fields.pop(field, None) + else: + self._fields[field] = new_value + # Check _cached_byte_size_dirty inline to improve performance, since scalar + # setters are called frequently. + if not self._cached_byte_size_dirty: + self._Modified() + + if field.containing_oneof: + def setter(self, new_value): + field_setter(self, new_value) + self._UpdateOneofState(field) + else: + setter = field_setter + + setter.__module__ = None + setter.__doc__ = 'Setter for %s.' % proto_field_name + + # Add a property to encapsulate the getter/setter. + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc)) + + +def _AddPropertiesForNonRepeatedCompositeField(field, cls): + """Adds a public property for a nonrepeated, composite protocol message field. + A composite field is a "group" or "message" field. + + Clients can use this property to get the value of the field, but cannot + assign to the property directly. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + # TODO(robinson): Remove duplication with similar method + # for non-repeated scalars. + proto_field_name = field.name + property_name = _PropertyName(proto_field_name) + + def getter(self): + field_value = self._fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = field._default_constructor(self) + + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + field_value = self._fields.setdefault(field, field_value) + return field_value + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + + # We define a setter just so we can throw an exception with a more + # helpful error message. + def setter(self, new_value): + raise AttributeError('Assignment not allowed to composite field ' + '"%s" in protocol message object.' % proto_field_name) + + # Add a property to encapsulate the getter. + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc)) + + +def _AddPropertiesForExtensions(descriptor, cls): + """Adds properties for all fields in this protocol message type.""" + extensions = descriptor.extensions_by_name + for extension_name, extension_field in extensions.items(): + constant_name = extension_name.upper() + '_FIELD_NUMBER' + setattr(cls, constant_name, extension_field.number) + + # TODO(amauryfa): Migrate all users of these attributes to functions like + # pool.FindExtensionByNumber(descriptor). + if descriptor.file is not None: + # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available. + pool = descriptor.file.pool + cls._extensions_by_number = pool._extensions_by_number[descriptor] + cls._extensions_by_name = pool._extensions_by_name[descriptor] + +def _AddStaticMethods(cls): + # TODO(robinson): This probably needs to be thread-safe(?) + def RegisterExtension(extension_handle): + extension_handle.containing_type = cls.DESCRIPTOR + # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available. + # pylint: disable=protected-access + cls.DESCRIPTOR.file.pool._AddExtensionDescriptor(extension_handle) + _AttachFieldHelpers(cls, extension_handle) + cls.RegisterExtension = staticmethod(RegisterExtension) + + def FromString(s): + message = cls() + message.MergeFromString(s) + return message + cls.FromString = staticmethod(FromString) + + +def _IsPresent(item): + """Given a (FieldDescriptor, value) tuple from _fields, return true if the + value should be included in the list returned by ListFields().""" + + if item[0].label == _FieldDescriptor.LABEL_REPEATED: + return bool(item[1]) + elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + return item[1]._is_present_in_parent + else: + return True + + +def _AddListFieldsMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def ListFields(self): + all_fields = [item for item in self._fields.items() if _IsPresent(item)] + all_fields.sort(key = lambda item: item[0].number) + return all_fields + + cls.ListFields = ListFields + +_PROTO3_ERROR_TEMPLATE = \ + ('Protocol message %s has no non-repeated submessage field "%s" ' + 'nor marked as optional') +_PROTO2_ERROR_TEMPLATE = 'Protocol message %s has no non-repeated field "%s"' + +def _AddHasFieldMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + is_proto3 = (message_descriptor.syntax == "proto3") + error_msg = _PROTO3_ERROR_TEMPLATE if is_proto3 else _PROTO2_ERROR_TEMPLATE + + hassable_fields = {} + for field in message_descriptor.fields: + if field.label == _FieldDescriptor.LABEL_REPEATED: + continue + # For proto3, only submessages and fields inside a oneof have presence. + if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and + not field.containing_oneof): + continue + hassable_fields[field.name] = field + + # Has methods are supported for oneof descriptors. + for oneof in message_descriptor.oneofs: + hassable_fields[oneof.name] = oneof + + def HasField(self, field_name): + try: + field = hassable_fields[field_name] + except KeyError: + raise ValueError(error_msg % (message_descriptor.full_name, field_name)) + + if isinstance(field, descriptor_mod.OneofDescriptor): + try: + return HasField(self, self._oneofs[field].name) + except KeyError: + return False + else: + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + value = self._fields.get(field) + return value is not None and value._is_present_in_parent + else: + return field in self._fields + + cls.HasField = HasField + + +def _AddClearFieldMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def ClearField(self, field_name): + try: + field = message_descriptor.fields_by_name[field_name] + except KeyError: + try: + field = message_descriptor.oneofs_by_name[field_name] + if field in self._oneofs: + field = self._oneofs[field] + else: + return + except KeyError: + raise ValueError('Protocol message %s has no "%s" field.' % + (message_descriptor.name, field_name)) + + if field in self._fields: + # To match the C++ implementation, we need to invalidate iterators + # for map fields when ClearField() happens. + if hasattr(self._fields[field], 'InvalidateIterators'): + self._fields[field].InvalidateIterators() + + # Note: If the field is a sub-message, its listener will still point + # at us. That's fine, because the worst than can happen is that it + # will call _Modified() and invalidate our byte size. Big deal. + del self._fields[field] + + if self._oneofs.get(field.containing_oneof, None) is field: + del self._oneofs[field.containing_oneof] + + # Always call _Modified() -- even if nothing was changed, this is + # a mutating method, and thus calling it should cause the field to become + # present in the parent message. + self._Modified() + + cls.ClearField = ClearField + + +def _AddClearExtensionMethod(cls): + """Helper for _AddMessageMethods().""" + def ClearExtension(self, extension_handle): + extension_dict._VerifyExtensionHandle(self, extension_handle) + + # Similar to ClearField(), above. + if extension_handle in self._fields: + del self._fields[extension_handle] + self._Modified() + cls.ClearExtension = ClearExtension + + +def _AddHasExtensionMethod(cls): + """Helper for _AddMessageMethods().""" + def HasExtension(self, extension_handle): + extension_dict._VerifyExtensionHandle(self, extension_handle) + if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: + raise KeyError('"%s" is repeated.' % extension_handle.full_name) + + if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + value = self._fields.get(extension_handle) + return value is not None and value._is_present_in_parent + else: + return extension_handle in self._fields + cls.HasExtension = HasExtension + +def _InternalUnpackAny(msg): + """Unpacks Any message and returns the unpacked message. + + This internal method is different from public Any Unpack method which takes + the target message as argument. _InternalUnpackAny method does not have + target message type and need to find the message type in descriptor pool. + + Args: + msg: An Any message to be unpacked. + + Returns: + The unpacked message. + """ + # TODO(amauryfa): Don't use the factory of generated messages. + # To make Any work with custom factories, use the message factory of the + # parent message. + # pylint: disable=g-import-not-at-top + from google.protobuf import symbol_database + factory = symbol_database.Default() + + type_url = msg.type_url + + if not type_url: + return None + + # TODO(haberman): For now we just strip the hostname. Better logic will be + # required. + type_name = type_url.split('/')[-1] + descriptor = factory.pool.FindMessageTypeByName(type_name) + + if descriptor is None: + return None + + message_class = factory.GetPrototype(descriptor) + message = message_class() + + message.ParseFromString(msg.value) + return message + + +def _AddEqualsMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def __eq__(self, other): + if (not isinstance(other, message_mod.Message) or + other.DESCRIPTOR != self.DESCRIPTOR): + return False + + if self is other: + return True + + if self.DESCRIPTOR.full_name == _AnyFullTypeName: + any_a = _InternalUnpackAny(self) + any_b = _InternalUnpackAny(other) + if any_a and any_b: + return any_a == any_b + + if not self.ListFields() == other.ListFields(): + return False + + # TODO(jieluo): Fix UnknownFieldSet to consider MessageSet extensions, + # then use it for the comparison. + unknown_fields = list(self._unknown_fields) + unknown_fields.sort() + other_unknown_fields = list(other._unknown_fields) + other_unknown_fields.sort() + return unknown_fields == other_unknown_fields + + cls.__eq__ = __eq__ + + +def _AddStrMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def __str__(self): + return text_format.MessageToString(self) + cls.__str__ = __str__ + + +def _AddReprMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def __repr__(self): + return text_format.MessageToString(self) + cls.__repr__ = __repr__ + + +def _AddUnicodeMethod(unused_message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def __unicode__(self): + return text_format.MessageToString(self, as_utf8=True).decode('utf-8') + cls.__unicode__ = __unicode__ + + +def _BytesForNonRepeatedElement(value, field_number, field_type): + """Returns the number of bytes needed to serialize a non-repeated element. + The returned byte count includes space for tag information and any + other additional space associated with serializing value. + + Args: + value: Value we're serializing. + field_number: Field number of this value. (Since the field number + is stored as part of a varint-encoded tag, this has an impact + on the total bytes required to serialize the value). + field_type: The type of the field. One of the TYPE_* constants + within FieldDescriptor. + """ + try: + fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] + return fn(field_number, value) + except KeyError: + raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) + + +def _AddByteSizeMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def ByteSize(self): + if not self._cached_byte_size_dirty: + return self._cached_byte_size + + size = 0 + descriptor = self.DESCRIPTOR + if descriptor.GetOptions().map_entry: + # Fields of map entry should always be serialized. + size = descriptor.fields_by_name['key']._sizer(self.key) + size += descriptor.fields_by_name['value']._sizer(self.value) + else: + for field_descriptor, field_value in self.ListFields(): + size += field_descriptor._sizer(field_value) + for tag_bytes, value_bytes in self._unknown_fields: + size += len(tag_bytes) + len(value_bytes) + + self._cached_byte_size = size + self._cached_byte_size_dirty = False + self._listener_for_children.dirty = False + return size + + cls.ByteSize = ByteSize + + +def _AddSerializeToStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def SerializeToString(self, **kwargs): + # Check if the message has all of its required fields set. + if not self.IsInitialized(): + raise message_mod.EncodeError( + 'Message %s is missing required fields: %s' % ( + self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) + return self.SerializePartialToString(**kwargs) + cls.SerializeToString = SerializeToString + + +def _AddSerializePartialToStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def SerializePartialToString(self, **kwargs): + out = BytesIO() + self._InternalSerialize(out.write, **kwargs) + return out.getvalue() + cls.SerializePartialToString = SerializePartialToString + + def InternalSerialize(self, write_bytes, deterministic=None): + if deterministic is None: + deterministic = ( + api_implementation.IsPythonDefaultSerializationDeterministic()) + else: + deterministic = bool(deterministic) + + descriptor = self.DESCRIPTOR + if descriptor.GetOptions().map_entry: + # Fields of map entry should always be serialized. + descriptor.fields_by_name['key']._encoder( + write_bytes, self.key, deterministic) + descriptor.fields_by_name['value']._encoder( + write_bytes, self.value, deterministic) + else: + for field_descriptor, field_value in self.ListFields(): + field_descriptor._encoder(write_bytes, field_value, deterministic) + for tag_bytes, value_bytes in self._unknown_fields: + write_bytes(tag_bytes) + write_bytes(value_bytes) + cls._InternalSerialize = InternalSerialize + + +def _AddMergeFromStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def MergeFromString(self, serialized): + serialized = memoryview(serialized) + length = len(serialized) + try: + if self._InternalParse(serialized, 0, length) != length: + # The only reason _InternalParse would return early is if it + # encountered an end-group tag. + raise message_mod.DecodeError('Unexpected end-group tag.') + except (IndexError, TypeError): + # Now ord(buf[p:p+1]) == ord('') gets TypeError. + raise message_mod.DecodeError('Truncated message.') + except struct.error as e: + raise message_mod.DecodeError(e) + return length # Return this for legacy reasons. + cls.MergeFromString = MergeFromString + + local_ReadTag = decoder.ReadTag + local_SkipField = decoder.SkipField + decoders_by_tag = cls._decoders_by_tag + + def InternalParse(self, buffer, pos, end): + """Create a message from serialized bytes. + + Args: + self: Message, instance of the proto message object. + buffer: memoryview of the serialized data. + pos: int, position to start in the serialized data. + end: int, end position of the serialized data. + + Returns: + Message object. + """ + # Guard against internal misuse, since this function is called internally + # quite extensively, and its easy to accidentally pass bytes. + assert isinstance(buffer, memoryview) + self._Modified() + field_dict = self._fields + # pylint: disable=protected-access + unknown_field_set = self._unknown_field_set + while pos != end: + (tag_bytes, new_pos) = local_ReadTag(buffer, pos) + field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) + if field_decoder is None: + if not self._unknown_fields: # pylint: disable=protected-access + self._unknown_fields = [] # pylint: disable=protected-access + if unknown_field_set is None: + # pylint: disable=protected-access + self._unknown_field_set = containers.UnknownFieldSet() + # pylint: disable=protected-access + unknown_field_set = self._unknown_field_set + # pylint: disable=protected-access + (tag, _) = decoder._DecodeVarint(tag_bytes, 0) + field_number, wire_type = wire_format.UnpackTag(tag) + if field_number == 0: + raise message_mod.DecodeError('Field number 0 is illegal.') + # TODO(jieluo): remove old_pos. + old_pos = new_pos + (data, new_pos) = decoder._DecodeUnknownField( + buffer, new_pos, wire_type) # pylint: disable=protected-access + if new_pos == -1: + return pos + # pylint: disable=protected-access + unknown_field_set._add(field_number, wire_type, data) + # TODO(jieluo): remove _unknown_fields. + new_pos = local_SkipField(buffer, old_pos, end, tag_bytes) + if new_pos == -1: + return pos + self._unknown_fields.append( + (tag_bytes, buffer[old_pos:new_pos].tobytes())) + pos = new_pos + else: + pos = field_decoder(buffer, new_pos, end, self, field_dict) + if field_desc: + self._UpdateOneofState(field_desc) + return pos + cls._InternalParse = InternalParse + + +def _AddIsInitializedMethod(message_descriptor, cls): + """Adds the IsInitialized and FindInitializationError methods to the + protocol message class.""" + + required_fields = [field for field in message_descriptor.fields + if field.label == _FieldDescriptor.LABEL_REQUIRED] + + def IsInitialized(self, errors=None): + """Checks if all required fields of a message are set. + + Args: + errors: A list which, if provided, will be populated with the field + paths of all missing required fields. + + Returns: + True iff the specified message has all required fields set. + """ + + # Performance is critical so we avoid HasField() and ListFields(). + + for field in required_fields: + if (field not in self._fields or + (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and + not self._fields[field]._is_present_in_parent)): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False + + for field, value in list(self._fields.items()): # dict can change size! + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.label == _FieldDescriptor.LABEL_REPEATED: + if (field.message_type.has_options and + field.message_type.GetOptions().map_entry): + continue + for element in value: + if not element.IsInitialized(): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False + elif value._is_present_in_parent and not value.IsInitialized(): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False + + return True + + cls.IsInitialized = IsInitialized + + def FindInitializationErrors(self): + """Finds required fields which are not initialized. + + Returns: + A list of strings. Each string is a path to an uninitialized field from + the top-level message, e.g. "foo.bar[5].baz". + """ + + errors = [] # simplify things + + for field in required_fields: + if not self.HasField(field.name): + errors.append(field.name) + + for field, value in self.ListFields(): + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.is_extension: + name = '(%s)' % field.full_name + else: + name = field.name + + if _IsMapField(field): + if _IsMessageMapField(field): + for key in value: + element = value[key] + prefix = '%s[%s].' % (name, key) + sub_errors = element.FindInitializationErrors() + errors += [prefix + error for error in sub_errors] + else: + # ScalarMaps can't have any initialization errors. + pass + elif field.label == _FieldDescriptor.LABEL_REPEATED: + for i in range(len(value)): + element = value[i] + prefix = '%s[%d].' % (name, i) + sub_errors = element.FindInitializationErrors() + errors += [prefix + error for error in sub_errors] + else: + prefix = name + '.' + sub_errors = value.FindInitializationErrors() + errors += [prefix + error for error in sub_errors] + + return errors + + cls.FindInitializationErrors = FindInitializationErrors + + +def _AddMergeFromMethod(cls): + LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED + CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE + + def MergeFrom(self, msg): + if not isinstance(msg, cls): + raise TypeError( + 'Parameter to MergeFrom() must be instance of same class: ' + 'expected %s got %s.' % (cls.__name__, msg.__class__.__name__)) + + assert msg is not self + self._Modified() + + fields = self._fields + + for field, value in msg._fields.items(): + if field.label == LABEL_REPEATED: + field_value = fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = field._default_constructor(self) + fields[field] = field_value + field_value.MergeFrom(value) + elif field.cpp_type == CPPTYPE_MESSAGE: + if value._is_present_in_parent: + field_value = fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = field._default_constructor(self) + fields[field] = field_value + field_value.MergeFrom(value) + else: + self._fields[field] = value + if field.containing_oneof: + self._UpdateOneofState(field) + + if msg._unknown_fields: + if not self._unknown_fields: + self._unknown_fields = [] + self._unknown_fields.extend(msg._unknown_fields) + # pylint: disable=protected-access + if self._unknown_field_set is None: + self._unknown_field_set = containers.UnknownFieldSet() + self._unknown_field_set._extend(msg._unknown_field_set) + + cls.MergeFrom = MergeFrom + + +def _AddWhichOneofMethod(message_descriptor, cls): + def WhichOneof(self, oneof_name): + """Returns the name of the currently set field inside a oneof, or None.""" + try: + field = message_descriptor.oneofs_by_name[oneof_name] + except KeyError: + raise ValueError( + 'Protocol message has no oneof "%s" field.' % oneof_name) + + nested_field = self._oneofs.get(field, None) + if nested_field is not None and self.HasField(nested_field.name): + return nested_field.name + else: + return None + + cls.WhichOneof = WhichOneof + + +def _Clear(self): + # Clear fields. + self._fields = {} + self._unknown_fields = () + # pylint: disable=protected-access + if self._unknown_field_set is not None: + self._unknown_field_set._clear() + self._unknown_field_set = None + + self._oneofs = {} + self._Modified() + + +def _UnknownFields(self): + if self._unknown_field_set is None: # pylint: disable=protected-access + # pylint: disable=protected-access + self._unknown_field_set = containers.UnknownFieldSet() + return self._unknown_field_set # pylint: disable=protected-access + + +def _DiscardUnknownFields(self): + self._unknown_fields = [] + self._unknown_field_set = None # pylint: disable=protected-access + for field, value in self.ListFields(): + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if _IsMapField(field): + if _IsMessageMapField(field): + for key in value: + value[key].DiscardUnknownFields() + elif field.label == _FieldDescriptor.LABEL_REPEATED: + for sub_message in value: + sub_message.DiscardUnknownFields() + else: + value.DiscardUnknownFields() + + +def _SetListener(self, listener): + if listener is None: + self._listener = message_listener_mod.NullMessageListener() + else: + self._listener = listener + + +def _AddMessageMethods(message_descriptor, cls): + """Adds implementations of all Message methods to cls.""" + _AddListFieldsMethod(message_descriptor, cls) + _AddHasFieldMethod(message_descriptor, cls) + _AddClearFieldMethod(message_descriptor, cls) + if message_descriptor.is_extendable: + _AddClearExtensionMethod(cls) + _AddHasExtensionMethod(cls) + _AddEqualsMethod(message_descriptor, cls) + _AddStrMethod(message_descriptor, cls) + _AddReprMethod(message_descriptor, cls) + _AddUnicodeMethod(message_descriptor, cls) + _AddByteSizeMethod(message_descriptor, cls) + _AddSerializeToStringMethod(message_descriptor, cls) + _AddSerializePartialToStringMethod(message_descriptor, cls) + _AddMergeFromStringMethod(message_descriptor, cls) + _AddIsInitializedMethod(message_descriptor, cls) + _AddMergeFromMethod(cls) + _AddWhichOneofMethod(message_descriptor, cls) + # Adds methods which do not depend on cls. + cls.Clear = _Clear + cls.UnknownFields = _UnknownFields + cls.DiscardUnknownFields = _DiscardUnknownFields + cls._SetListener = _SetListener + + +def _AddPrivateHelperMethods(message_descriptor, cls): + """Adds implementation of private helper methods to cls.""" + + def Modified(self): + """Sets the _cached_byte_size_dirty bit to true, + and propagates this to our listener iff this was a state change. + """ + + # Note: Some callers check _cached_byte_size_dirty before calling + # _Modified() as an extra optimization. So, if this method is ever + # changed such that it does stuff even when _cached_byte_size_dirty is + # already true, the callers need to be updated. + if not self._cached_byte_size_dirty: + self._cached_byte_size_dirty = True + self._listener_for_children.dirty = True + self._is_present_in_parent = True + self._listener.Modified() + + def _UpdateOneofState(self, field): + """Sets field as the active field in its containing oneof. + + Will also delete currently active field in the oneof, if it is different + from the argument. Does not mark the message as modified. + """ + other_field = self._oneofs.setdefault(field.containing_oneof, field) + if other_field is not field: + del self._fields[other_field] + self._oneofs[field.containing_oneof] = field + + cls._Modified = Modified + cls.SetInParent = Modified + cls._UpdateOneofState = _UpdateOneofState + + +class _Listener(object): + + """MessageListener implementation that a parent message registers with its + child message. + + In order to support semantics like: + + foo.bar.baz.qux = 23 + assert foo.HasField('bar') + + ...child objects must have back references to their parents. + This helper class is at the heart of this support. + """ + + def __init__(self, parent_message): + """Args: + parent_message: The message whose _Modified() method we should call when + we receive Modified() messages. + """ + # This listener establishes a back reference from a child (contained) object + # to its parent (containing) object. We make this a weak reference to avoid + # creating cyclic garbage when the client finishes with the 'parent' object + # in the tree. + if isinstance(parent_message, weakref.ProxyType): + self._parent_message_weakref = parent_message + else: + self._parent_message_weakref = weakref.proxy(parent_message) + + # As an optimization, we also indicate directly on the listener whether + # or not the parent message is dirty. This way we can avoid traversing + # up the tree in the common case. + self.dirty = False + + def Modified(self): + if self.dirty: + return + try: + # Propagate the signal to our parents iff this is the first field set. + self._parent_message_weakref._Modified() + except ReferenceError: + # We can get here if a client has kept a reference to a child object, + # and is now setting a field on it, but the child's parent has been + # garbage-collected. This is not an error. + pass + + +class _OneofListener(_Listener): + """Special listener implementation for setting composite oneof fields.""" + + def __init__(self, parent_message, field): + """Args: + parent_message: The message whose _Modified() method we should call when + we receive Modified() messages. + field: The descriptor of the field being set in the parent message. + """ + super(_OneofListener, self).__init__(parent_message) + self._field = field + + def Modified(self): + """Also updates the state of the containing oneof in the parent message.""" + try: + self._parent_message_weakref._UpdateOneofState(self._field) + super(_OneofListener, self).Modified() + except ReferenceError: + pass diff --git a/contrib/python/protobuf/py3/google/protobuf/internal/type_checkers.py b/contrib/python/protobuf/py3/google/protobuf/internal/type_checkers.py new file mode 100644 index 0000000000..eb66f9f6fb --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/internal/type_checkers.py @@ -0,0 +1,426 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Provides type checking routines. + +This module defines type checking utilities in the forms of dictionaries: + +VALUE_CHECKERS: A dictionary of field types and a value validation object. +TYPE_TO_BYTE_SIZE_FN: A dictionary with field types and a size computing + function. +TYPE_TO_SERIALIZE_METHOD: A dictionary with field types and serialization + function. +FIELD_TYPE_TO_WIRE_TYPE: A dictionary with field typed and their + corresponding wire types. +TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization + function. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +try: + import ctypes +except Exception: # pylint: disable=broad-except + ctypes = None + import struct +import numbers +import six + +if six.PY3: + long = int + +from google.protobuf.internal import api_implementation +from google.protobuf.internal import decoder +from google.protobuf.internal import encoder +from google.protobuf.internal import wire_format +from google.protobuf import descriptor + +_FieldDescriptor = descriptor.FieldDescriptor + + +def TruncateToFourByteFloat(original): + if ctypes: + return ctypes.c_float(original).value + else: + return struct.unpack('<f', struct.pack('<f', original))[0] + + +def ToShortestFloat(original): + """Returns the shortest float that has same value in wire.""" + # All 4 byte floats have between 6 and 9 significant digits, so we + # start with 6 as the lower bound. + # It has to be iterative because use '.9g' directly can not get rid + # of the noises for most values. For example if set a float_field=0.9 + # use '.9g' will print 0.899999976. + precision = 6 + rounded = float('{0:.{1}g}'.format(original, precision)) + while TruncateToFourByteFloat(rounded) != original: + precision += 1 + rounded = float('{0:.{1}g}'.format(original, precision)) + return rounded + + +def SupportsOpenEnums(field_descriptor): + return field_descriptor.containing_type.syntax == "proto3" + +def GetTypeChecker(field): + """Returns a type checker for a message field of the specified types. + + Args: + field: FieldDescriptor object for this field. + + Returns: + An instance of TypeChecker which can be used to verify the types + of values assigned to a field of the specified type. + """ + if (field.cpp_type == _FieldDescriptor.CPPTYPE_STRING and + field.type == _FieldDescriptor.TYPE_STRING): + return UnicodeValueChecker() + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: + if SupportsOpenEnums(field): + # When open enums are supported, any int32 can be assigned. + return _VALUE_CHECKERS[_FieldDescriptor.CPPTYPE_INT32] + else: + return EnumValueChecker(field.enum_type) + return _VALUE_CHECKERS[field.cpp_type] + + +# None of the typecheckers below make any attempt to guard against people +# subclassing builtin types and doing weird things. We're not trying to +# protect against malicious clients here, just people accidentally shooting +# themselves in the foot in obvious ways. + +class TypeChecker(object): + + """Type checker used to catch type errors as early as possible + when the client is setting scalar fields in protocol messages. + """ + + def __init__(self, *acceptable_types): + self._acceptable_types = acceptable_types + + def CheckValue(self, proposed_value): + """Type check the provided value and return it. + + The returned value might have been normalized to another type. + """ + if not isinstance(proposed_value, self._acceptable_types): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), self._acceptable_types)) + raise TypeError(message) + # Some field types(float, double and bool) accept other types, must + # convert to the correct type in such cases. + if self._acceptable_types: + if self._acceptable_types[0] in (bool, float): + return self._acceptable_types[0](proposed_value) + return proposed_value + + +class TypeCheckerWithDefault(TypeChecker): + + def __init__(self, default_value, *acceptable_types): + TypeChecker.__init__(self, *acceptable_types) + self._default_value = default_value + + def DefaultValue(self): + return self._default_value + + +# IntValueChecker and its subclasses perform integer type-checks +# and bounds-checks. +class IntValueChecker(object): + + """Checker used for integer fields. Performs type-check and range check.""" + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, numbers.Integral): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), six.integer_types)) + raise TypeError(message) + if not self._MIN <= int(proposed_value) <= self._MAX: + raise ValueError('Value out of range: %d' % proposed_value) + # We force 32-bit values to int and 64-bit values to long to make + # alternate implementations where the distinction is more significant + # (e.g. the C++ implementation) simpler. + proposed_value = self._TYPE(proposed_value) + return proposed_value + + def DefaultValue(self): + return 0 + + +class EnumValueChecker(object): + + """Checker used for enum fields. Performs type-check and range check.""" + + def __init__(self, enum_type): + self._enum_type = enum_type + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, numbers.Integral): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), six.integer_types)) + raise TypeError(message) + if int(proposed_value) not in self._enum_type.values_by_number: + raise ValueError('Unknown enum value: %d' % proposed_value) + return proposed_value + + def DefaultValue(self): + return self._enum_type.values[0].number + + +class UnicodeValueChecker(object): + + """Checker used for string fields. + + Always returns a unicode value, even if the input is of type str. + """ + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, (bytes, six.text_type)): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), (bytes, six.text_type))) + raise TypeError(message) + + # If the value is of type 'bytes' make sure that it is valid UTF-8 data. + if isinstance(proposed_value, bytes): + try: + proposed_value = proposed_value.decode('utf-8') + except UnicodeDecodeError: + raise ValueError('%.1024r has type bytes, but isn\'t valid UTF-8 ' + 'encoding. Non-UTF-8 strings must be converted to ' + 'unicode objects before being added.' % + (proposed_value)) + else: + try: + proposed_value.encode('utf8') + except UnicodeEncodeError: + raise ValueError('%.1024r isn\'t a valid unicode string and ' + 'can\'t be encoded in UTF-8.'% + (proposed_value)) + + return proposed_value + + def DefaultValue(self): + return u"" + + +class Int32ValueChecker(IntValueChecker): + # We're sure to use ints instead of longs here since comparison may be more + # efficient. + _MIN = -2147483648 + _MAX = 2147483647 + _TYPE = int + + +class Uint32ValueChecker(IntValueChecker): + _MIN = 0 + _MAX = (1 << 32) - 1 + _TYPE = int + + +class Int64ValueChecker(IntValueChecker): + _MIN = -(1 << 63) + _MAX = (1 << 63) - 1 + _TYPE = long + + +class Uint64ValueChecker(IntValueChecker): + _MIN = 0 + _MAX = (1 << 64) - 1 + _TYPE = long + + +# The max 4 bytes float is about 3.4028234663852886e+38 +_FLOAT_MAX = float.fromhex('0x1.fffffep+127') +_FLOAT_MIN = -_FLOAT_MAX +_INF = float('inf') +_NEG_INF = float('-inf') + + +class FloatValueChecker(object): + + """Checker used for float fields. Performs type-check and range check. + + Values exceeding a 32-bit float will be converted to inf/-inf. + """ + + def CheckValue(self, proposed_value): + """Check and convert proposed_value to float.""" + if not isinstance(proposed_value, numbers.Real): + message = ('%.1024r has type %s, but expected one of: numbers.Real' % + (proposed_value, type(proposed_value))) + raise TypeError(message) + converted_value = float(proposed_value) + # This inf rounding matches the C++ proto SafeDoubleToFloat logic. + if converted_value > _FLOAT_MAX: + return _INF + if converted_value < _FLOAT_MIN: + return _NEG_INF + + return TruncateToFourByteFloat(converted_value) + + def DefaultValue(self): + return 0.0 + + +# Type-checkers for all scalar CPPTYPEs. +_VALUE_CHECKERS = { + _FieldDescriptor.CPPTYPE_INT32: Int32ValueChecker(), + _FieldDescriptor.CPPTYPE_INT64: Int64ValueChecker(), + _FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(), + _FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(), + _FieldDescriptor.CPPTYPE_DOUBLE: TypeCheckerWithDefault( + 0.0, float, numbers.Real), + _FieldDescriptor.CPPTYPE_FLOAT: FloatValueChecker(), + _FieldDescriptor.CPPTYPE_BOOL: TypeCheckerWithDefault( + False, bool, numbers.Integral), + _FieldDescriptor.CPPTYPE_STRING: TypeCheckerWithDefault(b'', bytes), + } + + +# Map from field type to a function F, such that F(field_num, value) +# gives the total byte size for a value of the given type. This +# byte size includes tag information and any other additional space +# associated with serializing "value". +TYPE_TO_BYTE_SIZE_FN = { + _FieldDescriptor.TYPE_DOUBLE: wire_format.DoubleByteSize, + _FieldDescriptor.TYPE_FLOAT: wire_format.FloatByteSize, + _FieldDescriptor.TYPE_INT64: wire_format.Int64ByteSize, + _FieldDescriptor.TYPE_UINT64: wire_format.UInt64ByteSize, + _FieldDescriptor.TYPE_INT32: wire_format.Int32ByteSize, + _FieldDescriptor.TYPE_FIXED64: wire_format.Fixed64ByteSize, + _FieldDescriptor.TYPE_FIXED32: wire_format.Fixed32ByteSize, + _FieldDescriptor.TYPE_BOOL: wire_format.BoolByteSize, + _FieldDescriptor.TYPE_STRING: wire_format.StringByteSize, + _FieldDescriptor.TYPE_GROUP: wire_format.GroupByteSize, + _FieldDescriptor.TYPE_MESSAGE: wire_format.MessageByteSize, + _FieldDescriptor.TYPE_BYTES: wire_format.BytesByteSize, + _FieldDescriptor.TYPE_UINT32: wire_format.UInt32ByteSize, + _FieldDescriptor.TYPE_ENUM: wire_format.EnumByteSize, + _FieldDescriptor.TYPE_SFIXED32: wire_format.SFixed32ByteSize, + _FieldDescriptor.TYPE_SFIXED64: wire_format.SFixed64ByteSize, + _FieldDescriptor.TYPE_SINT32: wire_format.SInt32ByteSize, + _FieldDescriptor.TYPE_SINT64: wire_format.SInt64ByteSize + } + + +# Maps from field types to encoder constructors. +TYPE_TO_ENCODER = { + _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleEncoder, + _FieldDescriptor.TYPE_FLOAT: encoder.FloatEncoder, + _FieldDescriptor.TYPE_INT64: encoder.Int64Encoder, + _FieldDescriptor.TYPE_UINT64: encoder.UInt64Encoder, + _FieldDescriptor.TYPE_INT32: encoder.Int32Encoder, + _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Encoder, + _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Encoder, + _FieldDescriptor.TYPE_BOOL: encoder.BoolEncoder, + _FieldDescriptor.TYPE_STRING: encoder.StringEncoder, + _FieldDescriptor.TYPE_GROUP: encoder.GroupEncoder, + _FieldDescriptor.TYPE_MESSAGE: encoder.MessageEncoder, + _FieldDescriptor.TYPE_BYTES: encoder.BytesEncoder, + _FieldDescriptor.TYPE_UINT32: encoder.UInt32Encoder, + _FieldDescriptor.TYPE_ENUM: encoder.EnumEncoder, + _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Encoder, + _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Encoder, + _FieldDescriptor.TYPE_SINT32: encoder.SInt32Encoder, + _FieldDescriptor.TYPE_SINT64: encoder.SInt64Encoder, + } + + +# Maps from field types to sizer constructors. +TYPE_TO_SIZER = { + _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleSizer, + _FieldDescriptor.TYPE_FLOAT: encoder.FloatSizer, + _FieldDescriptor.TYPE_INT64: encoder.Int64Sizer, + _FieldDescriptor.TYPE_UINT64: encoder.UInt64Sizer, + _FieldDescriptor.TYPE_INT32: encoder.Int32Sizer, + _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Sizer, + _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Sizer, + _FieldDescriptor.TYPE_BOOL: encoder.BoolSizer, + _FieldDescriptor.TYPE_STRING: encoder.StringSizer, + _FieldDescriptor.TYPE_GROUP: encoder.GroupSizer, + _FieldDescriptor.TYPE_MESSAGE: encoder.MessageSizer, + _FieldDescriptor.TYPE_BYTES: encoder.BytesSizer, + _FieldDescriptor.TYPE_UINT32: encoder.UInt32Sizer, + _FieldDescriptor.TYPE_ENUM: encoder.EnumSizer, + _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Sizer, + _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Sizer, + _FieldDescriptor.TYPE_SINT32: encoder.SInt32Sizer, + _FieldDescriptor.TYPE_SINT64: encoder.SInt64Sizer, + } + + +# Maps from field type to a decoder constructor. +TYPE_TO_DECODER = { + _FieldDescriptor.TYPE_DOUBLE: decoder.DoubleDecoder, + _FieldDescriptor.TYPE_FLOAT: decoder.FloatDecoder, + _FieldDescriptor.TYPE_INT64: decoder.Int64Decoder, + _FieldDescriptor.TYPE_UINT64: decoder.UInt64Decoder, + _FieldDescriptor.TYPE_INT32: decoder.Int32Decoder, + _FieldDescriptor.TYPE_FIXED64: decoder.Fixed64Decoder, + _FieldDescriptor.TYPE_FIXED32: decoder.Fixed32Decoder, + _FieldDescriptor.TYPE_BOOL: decoder.BoolDecoder, + _FieldDescriptor.TYPE_STRING: decoder.StringDecoder, + _FieldDescriptor.TYPE_GROUP: decoder.GroupDecoder, + _FieldDescriptor.TYPE_MESSAGE: decoder.MessageDecoder, + _FieldDescriptor.TYPE_BYTES: decoder.BytesDecoder, + _FieldDescriptor.TYPE_UINT32: decoder.UInt32Decoder, + _FieldDescriptor.TYPE_ENUM: decoder.EnumDecoder, + _FieldDescriptor.TYPE_SFIXED32: decoder.SFixed32Decoder, + _FieldDescriptor.TYPE_SFIXED64: decoder.SFixed64Decoder, + _FieldDescriptor.TYPE_SINT32: decoder.SInt32Decoder, + _FieldDescriptor.TYPE_SINT64: decoder.SInt64Decoder, + } + +# Maps from field type to expected wiretype. +FIELD_TYPE_TO_WIRE_TYPE = { + _FieldDescriptor.TYPE_DOUBLE: wire_format.WIRETYPE_FIXED64, + _FieldDescriptor.TYPE_FLOAT: wire_format.WIRETYPE_FIXED32, + _FieldDescriptor.TYPE_INT64: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_UINT64: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_INT32: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_FIXED64: wire_format.WIRETYPE_FIXED64, + _FieldDescriptor.TYPE_FIXED32: wire_format.WIRETYPE_FIXED32, + _FieldDescriptor.TYPE_BOOL: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_STRING: + wire_format.WIRETYPE_LENGTH_DELIMITED, + _FieldDescriptor.TYPE_GROUP: wire_format.WIRETYPE_START_GROUP, + _FieldDescriptor.TYPE_MESSAGE: + wire_format.WIRETYPE_LENGTH_DELIMITED, + _FieldDescriptor.TYPE_BYTES: + wire_format.WIRETYPE_LENGTH_DELIMITED, + _FieldDescriptor.TYPE_UINT32: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_ENUM: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_SFIXED32: wire_format.WIRETYPE_FIXED32, + _FieldDescriptor.TYPE_SFIXED64: wire_format.WIRETYPE_FIXED64, + _FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT, + } diff --git a/contrib/python/protobuf/py3/google/protobuf/internal/well_known_types.py b/contrib/python/protobuf/py3/google/protobuf/internal/well_known_types.py new file mode 100644 index 0000000000..6f55d6b17b --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/internal/well_known_types.py @@ -0,0 +1,863 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains well known classes. + +This files defines well known classes which need extra maintenance including: + - Any + - Duration + - FieldMask + - Struct + - Timestamp +""" + +__author__ = 'jieluo@google.com (Jie Luo)' + +import calendar +from datetime import datetime +from datetime import timedelta +import six + +try: + # Since python 3 + import collections.abc as collections_abc +except ImportError: + # Won't work after python 3.8 + import collections as collections_abc + +from google.protobuf.descriptor import FieldDescriptor + +_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S' +_NANOS_PER_SECOND = 1000000000 +_NANOS_PER_MILLISECOND = 1000000 +_NANOS_PER_MICROSECOND = 1000 +_MILLIS_PER_SECOND = 1000 +_MICROS_PER_SECOND = 1000000 +_SECONDS_PER_DAY = 24 * 3600 +_DURATION_SECONDS_MAX = 315576000000 + + +class Any(object): + """Class for Any Message type.""" + + __slots__ = () + + def Pack(self, msg, type_url_prefix='type.googleapis.com/', + deterministic=None): + """Packs the specified message into current Any message.""" + if len(type_url_prefix) < 1 or type_url_prefix[-1] != '/': + self.type_url = '%s/%s' % (type_url_prefix, msg.DESCRIPTOR.full_name) + else: + self.type_url = '%s%s' % (type_url_prefix, msg.DESCRIPTOR.full_name) + self.value = msg.SerializeToString(deterministic=deterministic) + + def Unpack(self, msg): + """Unpacks the current Any message into specified message.""" + descriptor = msg.DESCRIPTOR + if not self.Is(descriptor): + return False + msg.ParseFromString(self.value) + return True + + def TypeName(self): + """Returns the protobuf type name of the inner message.""" + # Only last part is to be used: b/25630112 + return self.type_url.split('/')[-1] + + def Is(self, descriptor): + """Checks if this Any represents the given protobuf type.""" + return '/' in self.type_url and self.TypeName() == descriptor.full_name + + +_EPOCH_DATETIME = datetime.utcfromtimestamp(0) + + +class Timestamp(object): + """Class for Timestamp message type.""" + + __slots__ = () + + def ToJsonString(self): + """Converts Timestamp to RFC 3339 date string format. + + Returns: + A string converted from timestamp. The string is always Z-normalized + and uses 3, 6 or 9 fractional digits as required to represent the + exact time. Example of the return format: '1972-01-01T10:00:20.021Z' + """ + nanos = self.nanos % _NANOS_PER_SECOND + total_sec = self.seconds + (self.nanos - nanos) // _NANOS_PER_SECOND + seconds = total_sec % _SECONDS_PER_DAY + days = (total_sec - seconds) // _SECONDS_PER_DAY + dt = datetime(1970, 1, 1) + timedelta(days, seconds) + + result = dt.isoformat() + if (nanos % 1e9) == 0: + # If there are 0 fractional digits, the fractional + # point '.' should be omitted when serializing. + return result + 'Z' + if (nanos % 1e6) == 0: + # Serialize 3 fractional digits. + return result + '.%03dZ' % (nanos / 1e6) + if (nanos % 1e3) == 0: + # Serialize 6 fractional digits. + return result + '.%06dZ' % (nanos / 1e3) + # Serialize 9 fractional digits. + return result + '.%09dZ' % nanos + + def FromJsonString(self, value): + """Parse a RFC 3339 date string format to Timestamp. + + Args: + value: A date string. Any fractional digits (or none) and any offset are + accepted as long as they fit into nano-seconds precision. + Example of accepted format: '1972-01-01T10:00:20.021-05:00' + + Raises: + ValueError: On parsing problems. + """ + timezone_offset = value.find('Z') + if timezone_offset == -1: + timezone_offset = value.find('+') + if timezone_offset == -1: + timezone_offset = value.rfind('-') + if timezone_offset == -1: + raise ValueError( + 'Failed to parse timestamp: missing valid timezone offset.') + time_value = value[0:timezone_offset] + # Parse datetime and nanos. + point_position = time_value.find('.') + if point_position == -1: + second_value = time_value + nano_value = '' + else: + second_value = time_value[:point_position] + nano_value = time_value[point_position + 1:] + if 't' in second_value: + raise ValueError( + 'time data \'{0}\' does not match format \'%Y-%m-%dT%H:%M:%S\', ' + 'lowercase \'t\' is not accepted'.format(second_value)) + date_object = datetime.strptime(second_value, _TIMESTAMPFOMAT) + td = date_object - datetime(1970, 1, 1) + seconds = td.seconds + td.days * _SECONDS_PER_DAY + if len(nano_value) > 9: + raise ValueError( + 'Failed to parse Timestamp: nanos {0} more than ' + '9 fractional digits.'.format(nano_value)) + if nano_value: + nanos = round(float('0.' + nano_value) * 1e9) + else: + nanos = 0 + # Parse timezone offsets. + if value[timezone_offset] == 'Z': + if len(value) != timezone_offset + 1: + raise ValueError('Failed to parse timestamp: invalid trailing' + ' data {0}.'.format(value)) + else: + timezone = value[timezone_offset:] + pos = timezone.find(':') + if pos == -1: + raise ValueError( + 'Invalid timezone offset value: {0}.'.format(timezone)) + if timezone[0] == '+': + seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60 + else: + seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60 + # Set seconds and nanos + self.seconds = int(seconds) + self.nanos = int(nanos) + + def GetCurrentTime(self): + """Get the current UTC into Timestamp.""" + self.FromDatetime(datetime.utcnow()) + + def ToNanoseconds(self): + """Converts Timestamp to nanoseconds since epoch.""" + return self.seconds * _NANOS_PER_SECOND + self.nanos + + def ToMicroseconds(self): + """Converts Timestamp to microseconds since epoch.""" + return (self.seconds * _MICROS_PER_SECOND + + self.nanos // _NANOS_PER_MICROSECOND) + + def ToMilliseconds(self): + """Converts Timestamp to milliseconds since epoch.""" + return (self.seconds * _MILLIS_PER_SECOND + + self.nanos // _NANOS_PER_MILLISECOND) + + def ToSeconds(self): + """Converts Timestamp to seconds since epoch.""" + return self.seconds + + def FromNanoseconds(self, nanos): + """Converts nanoseconds since epoch to Timestamp.""" + self.seconds = nanos // _NANOS_PER_SECOND + self.nanos = nanos % _NANOS_PER_SECOND + + def FromMicroseconds(self, micros): + """Converts microseconds since epoch to Timestamp.""" + self.seconds = micros // _MICROS_PER_SECOND + self.nanos = (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND + + def FromMilliseconds(self, millis): + """Converts milliseconds since epoch to Timestamp.""" + self.seconds = millis // _MILLIS_PER_SECOND + self.nanos = (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND + + def FromSeconds(self, seconds): + """Converts seconds since epoch to Timestamp.""" + self.seconds = seconds + self.nanos = 0 + + def ToDatetime(self): + """Converts Timestamp to datetime.""" + return _EPOCH_DATETIME + timedelta( + seconds=self.seconds, microseconds=_RoundTowardZero( + self.nanos, _NANOS_PER_MICROSECOND)) + + def FromDatetime(self, dt): + """Converts datetime to Timestamp.""" + # Using this guide: http://wiki.python.org/moin/WorkingWithTime + # And this conversion guide: http://docs.python.org/library/time.html + + # Turn the date parameter into a tuple (struct_time) that can then be + # manipulated into a long value of seconds. During the conversion from + # struct_time to long, the source date in UTC, and so it follows that the + # correct transformation is calendar.timegm() + self.seconds = calendar.timegm(dt.utctimetuple()) + self.nanos = dt.microsecond * _NANOS_PER_MICROSECOND + + +class Duration(object): + """Class for Duration message type.""" + + __slots__ = () + + def ToJsonString(self): + """Converts Duration to string format. + + Returns: + A string converted from self. The string format will contains + 3, 6, or 9 fractional digits depending on the precision required to + represent the exact Duration value. For example: "1s", "1.010s", + "1.000000100s", "-3.100s" + """ + _CheckDurationValid(self.seconds, self.nanos) + if self.seconds < 0 or self.nanos < 0: + result = '-' + seconds = - self.seconds + int((0 - self.nanos) // 1e9) + nanos = (0 - self.nanos) % 1e9 + else: + result = '' + seconds = self.seconds + int(self.nanos // 1e9) + nanos = self.nanos % 1e9 + result += '%d' % seconds + if (nanos % 1e9) == 0: + # If there are 0 fractional digits, the fractional + # point '.' should be omitted when serializing. + return result + 's' + if (nanos % 1e6) == 0: + # Serialize 3 fractional digits. + return result + '.%03ds' % (nanos / 1e6) + if (nanos % 1e3) == 0: + # Serialize 6 fractional digits. + return result + '.%06ds' % (nanos / 1e3) + # Serialize 9 fractional digits. + return result + '.%09ds' % nanos + + def FromJsonString(self, value): + """Converts a string to Duration. + + Args: + value: A string to be converted. The string must end with 's'. Any + fractional digits (or none) are accepted as long as they fit into + precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s + + Raises: + ValueError: On parsing problems. + """ + if len(value) < 1 or value[-1] != 's': + raise ValueError( + 'Duration must end with letter "s": {0}.'.format(value)) + try: + pos = value.find('.') + if pos == -1: + seconds = int(value[:-1]) + nanos = 0 + else: + seconds = int(value[:pos]) + if value[0] == '-': + nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9)) + else: + nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9)) + _CheckDurationValid(seconds, nanos) + self.seconds = seconds + self.nanos = nanos + except ValueError as e: + raise ValueError( + 'Couldn\'t parse duration: {0} : {1}.'.format(value, e)) + + def ToNanoseconds(self): + """Converts a Duration to nanoseconds.""" + return self.seconds * _NANOS_PER_SECOND + self.nanos + + def ToMicroseconds(self): + """Converts a Duration to microseconds.""" + micros = _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND) + return self.seconds * _MICROS_PER_SECOND + micros + + def ToMilliseconds(self): + """Converts a Duration to milliseconds.""" + millis = _RoundTowardZero(self.nanos, _NANOS_PER_MILLISECOND) + return self.seconds * _MILLIS_PER_SECOND + millis + + def ToSeconds(self): + """Converts a Duration to seconds.""" + return self.seconds + + def FromNanoseconds(self, nanos): + """Converts nanoseconds to Duration.""" + self._NormalizeDuration(nanos // _NANOS_PER_SECOND, + nanos % _NANOS_PER_SECOND) + + def FromMicroseconds(self, micros): + """Converts microseconds to Duration.""" + self._NormalizeDuration( + micros // _MICROS_PER_SECOND, + (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND) + + def FromMilliseconds(self, millis): + """Converts milliseconds to Duration.""" + self._NormalizeDuration( + millis // _MILLIS_PER_SECOND, + (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND) + + def FromSeconds(self, seconds): + """Converts seconds to Duration.""" + self.seconds = seconds + self.nanos = 0 + + def ToTimedelta(self): + """Converts Duration to timedelta.""" + return timedelta( + seconds=self.seconds, microseconds=_RoundTowardZero( + self.nanos, _NANOS_PER_MICROSECOND)) + + def FromTimedelta(self, td): + """Converts timedelta to Duration.""" + self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY, + td.microseconds * _NANOS_PER_MICROSECOND) + + def _NormalizeDuration(self, seconds, nanos): + """Set Duration by seconds and nanos.""" + # Force nanos to be negative if the duration is negative. + if seconds < 0 and nanos > 0: + seconds += 1 + nanos -= _NANOS_PER_SECOND + self.seconds = seconds + self.nanos = nanos + + +def _CheckDurationValid(seconds, nanos): + if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX: + raise ValueError( + 'Duration is not valid: Seconds {0} must be in range ' + '[-315576000000, 315576000000].'.format(seconds)) + if nanos <= -_NANOS_PER_SECOND or nanos >= _NANOS_PER_SECOND: + raise ValueError( + 'Duration is not valid: Nanos {0} must be in range ' + '[-999999999, 999999999].'.format(nanos)) + if (nanos < 0 and seconds > 0) or (nanos > 0 and seconds < 0): + raise ValueError( + 'Duration is not valid: Sign mismatch.') + + +def _RoundTowardZero(value, divider): + """Truncates the remainder part after division.""" + # For some languages, the sign of the remainder is implementation + # dependent if any of the operands is negative. Here we enforce + # "rounded toward zero" semantics. For example, for (-5) / 2 an + # implementation may give -3 as the result with the remainder being + # 1. This function ensures we always return -2 (closer to zero). + result = value // divider + remainder = value % divider + if result < 0 and remainder > 0: + return result + 1 + else: + return result + + +class FieldMask(object): + """Class for FieldMask message type.""" + + __slots__ = () + + def ToJsonString(self): + """Converts FieldMask to string according to proto3 JSON spec.""" + camelcase_paths = [] + for path in self.paths: + camelcase_paths.append(_SnakeCaseToCamelCase(path)) + return ','.join(camelcase_paths) + + def FromJsonString(self, value): + """Converts string to FieldMask according to proto3 JSON spec.""" + self.Clear() + if value: + for path in value.split(','): + self.paths.append(_CamelCaseToSnakeCase(path)) + + def IsValidForDescriptor(self, message_descriptor): + """Checks whether the FieldMask is valid for Message Descriptor.""" + for path in self.paths: + if not _IsValidPath(message_descriptor, path): + return False + return True + + def AllFieldsFromDescriptor(self, message_descriptor): + """Gets all direct fields of Message Descriptor to FieldMask.""" + self.Clear() + for field in message_descriptor.fields: + self.paths.append(field.name) + + def CanonicalFormFromMask(self, mask): + """Converts a FieldMask to the canonical form. + + Removes paths that are covered by another path. For example, + "foo.bar" is covered by "foo" and will be removed if "foo" + is also in the FieldMask. Then sorts all paths in alphabetical order. + + Args: + mask: The original FieldMask to be converted. + """ + tree = _FieldMaskTree(mask) + tree.ToFieldMask(self) + + def Union(self, mask1, mask2): + """Merges mask1 and mask2 into this FieldMask.""" + _CheckFieldMaskMessage(mask1) + _CheckFieldMaskMessage(mask2) + tree = _FieldMaskTree(mask1) + tree.MergeFromFieldMask(mask2) + tree.ToFieldMask(self) + + def Intersect(self, mask1, mask2): + """Intersects mask1 and mask2 into this FieldMask.""" + _CheckFieldMaskMessage(mask1) + _CheckFieldMaskMessage(mask2) + tree = _FieldMaskTree(mask1) + intersection = _FieldMaskTree() + for path in mask2.paths: + tree.IntersectPath(path, intersection) + intersection.ToFieldMask(self) + + def MergeMessage( + self, source, destination, + replace_message_field=False, replace_repeated_field=False): + """Merges fields specified in FieldMask from source to destination. + + Args: + source: Source message. + destination: The destination message to be merged into. + replace_message_field: Replace message field if True. Merge message + field if False. + replace_repeated_field: Replace repeated field if True. Append + elements of repeated field if False. + """ + tree = _FieldMaskTree(self) + tree.MergeMessage( + source, destination, replace_message_field, replace_repeated_field) + + +def _IsValidPath(message_descriptor, path): + """Checks whether the path is valid for Message Descriptor.""" + parts = path.split('.') + last = parts.pop() + for name in parts: + field = message_descriptor.fields_by_name.get(name) + if (field is None or + field.label == FieldDescriptor.LABEL_REPEATED or + field.type != FieldDescriptor.TYPE_MESSAGE): + return False + message_descriptor = field.message_type + return last in message_descriptor.fields_by_name + + +def _CheckFieldMaskMessage(message): + """Raises ValueError if message is not a FieldMask.""" + message_descriptor = message.DESCRIPTOR + if (message_descriptor.name != 'FieldMask' or + message_descriptor.file.name != 'google/protobuf/field_mask.proto'): + raise ValueError('Message {0} is not a FieldMask.'.format( + message_descriptor.full_name)) + + +def _SnakeCaseToCamelCase(path_name): + """Converts a path name from snake_case to camelCase.""" + result = [] + after_underscore = False + for c in path_name: + if c.isupper(): + raise ValueError( + 'Fail to print FieldMask to Json string: Path name ' + '{0} must not contain uppercase letters.'.format(path_name)) + if after_underscore: + if c.islower(): + result.append(c.upper()) + after_underscore = False + else: + raise ValueError( + 'Fail to print FieldMask to Json string: The ' + 'character after a "_" must be a lowercase letter ' + 'in path name {0}.'.format(path_name)) + elif c == '_': + after_underscore = True + else: + result += c + + if after_underscore: + raise ValueError('Fail to print FieldMask to Json string: Trailing "_" ' + 'in path name {0}.'.format(path_name)) + return ''.join(result) + + +def _CamelCaseToSnakeCase(path_name): + """Converts a field name from camelCase to snake_case.""" + result = [] + for c in path_name: + if c == '_': + raise ValueError('Fail to parse FieldMask: Path name ' + '{0} must not contain "_"s.'.format(path_name)) + if c.isupper(): + result += '_' + result += c.lower() + else: + result += c + return ''.join(result) + + +class _FieldMaskTree(object): + """Represents a FieldMask in a tree structure. + + For example, given a FieldMask "foo.bar,foo.baz,bar.baz", + the FieldMaskTree will be: + [_root] -+- foo -+- bar + | | + | +- baz + | + +- bar --- baz + In the tree, each leaf node represents a field path. + """ + + __slots__ = ('_root',) + + def __init__(self, field_mask=None): + """Initializes the tree by FieldMask.""" + self._root = {} + if field_mask: + self.MergeFromFieldMask(field_mask) + + def MergeFromFieldMask(self, field_mask): + """Merges a FieldMask to the tree.""" + for path in field_mask.paths: + self.AddPath(path) + + def AddPath(self, path): + """Adds a field path into the tree. + + If the field path to add is a sub-path of an existing field path + in the tree (i.e., a leaf node), it means the tree already matches + the given path so nothing will be added to the tree. If the path + matches an existing non-leaf node in the tree, that non-leaf node + will be turned into a leaf node with all its children removed because + the path matches all the node's children. Otherwise, a new path will + be added. + + Args: + path: The field path to add. + """ + node = self._root + for name in path.split('.'): + if name not in node: + node[name] = {} + elif not node[name]: + # Pre-existing empty node implies we already have this entire tree. + return + node = node[name] + # Remove any sub-trees we might have had. + node.clear() + + def ToFieldMask(self, field_mask): + """Converts the tree to a FieldMask.""" + field_mask.Clear() + _AddFieldPaths(self._root, '', field_mask) + + def IntersectPath(self, path, intersection): + """Calculates the intersection part of a field path with this tree. + + Args: + path: The field path to calculates. + intersection: The out tree to record the intersection part. + """ + node = self._root + for name in path.split('.'): + if name not in node: + return + elif not node[name]: + intersection.AddPath(path) + return + node = node[name] + intersection.AddLeafNodes(path, node) + + def AddLeafNodes(self, prefix, node): + """Adds leaf nodes begin with prefix to this tree.""" + if not node: + self.AddPath(prefix) + for name in node: + child_path = prefix + '.' + name + self.AddLeafNodes(child_path, node[name]) + + def MergeMessage( + self, source, destination, + replace_message, replace_repeated): + """Merge all fields specified by this tree from source to destination.""" + _MergeMessage( + self._root, source, destination, replace_message, replace_repeated) + + +def _StrConvert(value): + """Converts value to str if it is not.""" + # This file is imported by c extension and some methods like ClearField + # requires string for the field name. py2/py3 has different text + # type and may use unicode. + if not isinstance(value, str): + return value.encode('utf-8') + return value + + +def _MergeMessage( + node, source, destination, replace_message, replace_repeated): + """Merge all fields specified by a sub-tree from source to destination.""" + source_descriptor = source.DESCRIPTOR + for name in node: + child = node[name] + field = source_descriptor.fields_by_name[name] + if field is None: + raise ValueError('Error: Can\'t find field {0} in message {1}.'.format( + name, source_descriptor.full_name)) + if child: + # Sub-paths are only allowed for singular message fields. + if (field.label == FieldDescriptor.LABEL_REPEATED or + field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE): + raise ValueError('Error: Field {0} in message {1} is not a singular ' + 'message field and cannot have sub-fields.'.format( + name, source_descriptor.full_name)) + if source.HasField(name): + _MergeMessage( + child, getattr(source, name), getattr(destination, name), + replace_message, replace_repeated) + continue + if field.label == FieldDescriptor.LABEL_REPEATED: + if replace_repeated: + destination.ClearField(_StrConvert(name)) + repeated_source = getattr(source, name) + repeated_destination = getattr(destination, name) + repeated_destination.MergeFrom(repeated_source) + else: + if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + if replace_message: + destination.ClearField(_StrConvert(name)) + if source.HasField(name): + getattr(destination, name).MergeFrom(getattr(source, name)) + else: + setattr(destination, name, getattr(source, name)) + + +def _AddFieldPaths(node, prefix, field_mask): + """Adds the field paths descended from node to field_mask.""" + if not node and prefix: + field_mask.paths.append(prefix) + return + for name in sorted(node): + if prefix: + child_path = prefix + '.' + name + else: + child_path = name + _AddFieldPaths(node[name], child_path, field_mask) + + +_INT_OR_FLOAT = six.integer_types + (float,) + + +def _SetStructValue(struct_value, value): + if value is None: + struct_value.null_value = 0 + elif isinstance(value, bool): + # Note: this check must come before the number check because in Python + # True and False are also considered numbers. + struct_value.bool_value = value + elif isinstance(value, six.string_types): + struct_value.string_value = value + elif isinstance(value, _INT_OR_FLOAT): + struct_value.number_value = value + elif isinstance(value, (dict, Struct)): + struct_value.struct_value.Clear() + struct_value.struct_value.update(value) + elif isinstance(value, (list, ListValue)): + struct_value.list_value.Clear() + struct_value.list_value.extend(value) + else: + raise ValueError('Unexpected type') + + +def _GetStructValue(struct_value): + which = struct_value.WhichOneof('kind') + if which == 'struct_value': + return struct_value.struct_value + elif which == 'null_value': + return None + elif which == 'number_value': + return struct_value.number_value + elif which == 'string_value': + return struct_value.string_value + elif which == 'bool_value': + return struct_value.bool_value + elif which == 'list_value': + return struct_value.list_value + elif which is None: + raise ValueError('Value not set') + + +class Struct(object): + """Class for Struct message type.""" + + __slots__ = () + + def __getitem__(self, key): + return _GetStructValue(self.fields[key]) + + def __contains__(self, item): + return item in self.fields + + def __setitem__(self, key, value): + _SetStructValue(self.fields[key], value) + + def __delitem__(self, key): + del self.fields[key] + + def __len__(self): + return len(self.fields) + + def __iter__(self): + return iter(self.fields) + + def keys(self): # pylint: disable=invalid-name + return self.fields.keys() + + def values(self): # pylint: disable=invalid-name + return [self[key] for key in self] + + def items(self): # pylint: disable=invalid-name + return [(key, self[key]) for key in self] + + def get_or_create_list(self, key): + """Returns a list for this key, creating if it didn't exist already.""" + if not self.fields[key].HasField('list_value'): + # Clear will mark list_value modified which will indeed create a list. + self.fields[key].list_value.Clear() + return self.fields[key].list_value + + def get_or_create_struct(self, key): + """Returns a struct for this key, creating if it didn't exist already.""" + if not self.fields[key].HasField('struct_value'): + # Clear will mark struct_value modified which will indeed create a struct. + self.fields[key].struct_value.Clear() + return self.fields[key].struct_value + + def update(self, dictionary): # pylint: disable=invalid-name + for key, value in dictionary.items(): + _SetStructValue(self.fields[key], value) + +collections_abc.MutableMapping.register(Struct) + + +class ListValue(object): + """Class for ListValue message type.""" + + __slots__ = () + + def __len__(self): + return len(self.values) + + def append(self, value): + _SetStructValue(self.values.add(), value) + + def extend(self, elem_seq): + for value in elem_seq: + self.append(value) + + def __getitem__(self, index): + """Retrieves item by the specified index.""" + return _GetStructValue(self.values.__getitem__(index)) + + def __setitem__(self, index, value): + _SetStructValue(self.values.__getitem__(index), value) + + def __delitem__(self, key): + del self.values[key] + + def items(self): + for i in range(len(self)): + yield self[i] + + def add_struct(self): + """Appends and returns a struct value as the next value in the list.""" + struct_value = self.values.add().struct_value + # Clear will mark struct_value modified which will indeed create a struct. + struct_value.Clear() + return struct_value + + def add_list(self): + """Appends and returns a list value as the next value in the list.""" + list_value = self.values.add().list_value + # Clear will mark list_value modified which will indeed create a list. + list_value.Clear() + return list_value + +collections_abc.MutableSequence.register(ListValue) + + +WKTBASES = { + 'google.protobuf.Any': Any, + 'google.protobuf.Duration': Duration, + 'google.protobuf.FieldMask': FieldMask, + 'google.protobuf.ListValue': ListValue, + 'google.protobuf.Struct': Struct, + 'google.protobuf.Timestamp': Timestamp, +} diff --git a/contrib/python/protobuf/py3/google/protobuf/internal/wire_format.py b/contrib/python/protobuf/py3/google/protobuf/internal/wire_format.py new file mode 100644 index 0000000000..883f525585 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/internal/wire_format.py @@ -0,0 +1,268 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Constants and static functions to support protocol buffer wire format.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import struct +from google.protobuf import descriptor +from google.protobuf import message + + +TAG_TYPE_BITS = 3 # Number of bits used to hold type info in a proto tag. +TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1 # 0x7 + +# These numbers identify the wire type of a protocol buffer value. +# We use the least-significant TAG_TYPE_BITS bits of the varint-encoded +# tag-and-type to store one of these WIRETYPE_* constants. +# These values must match WireType enum in google/protobuf/wire_format.h. +WIRETYPE_VARINT = 0 +WIRETYPE_FIXED64 = 1 +WIRETYPE_LENGTH_DELIMITED = 2 +WIRETYPE_START_GROUP = 3 +WIRETYPE_END_GROUP = 4 +WIRETYPE_FIXED32 = 5 +_WIRETYPE_MAX = 5 + + +# Bounds for various integer types. +INT32_MAX = int((1 << 31) - 1) +INT32_MIN = int(-(1 << 31)) +UINT32_MAX = (1 << 32) - 1 + +INT64_MAX = (1 << 63) - 1 +INT64_MIN = -(1 << 63) +UINT64_MAX = (1 << 64) - 1 + +# "struct" format strings that will encode/decode the specified formats. +FORMAT_UINT32_LITTLE_ENDIAN = '<I' +FORMAT_UINT64_LITTLE_ENDIAN = '<Q' +FORMAT_FLOAT_LITTLE_ENDIAN = '<f' +FORMAT_DOUBLE_LITTLE_ENDIAN = '<d' + + +# We'll have to provide alternate implementations of AppendLittleEndian*() on +# any architectures where these checks fail. +if struct.calcsize(FORMAT_UINT32_LITTLE_ENDIAN) != 4: + raise AssertionError('Format "I" is not a 32-bit number.') +if struct.calcsize(FORMAT_UINT64_LITTLE_ENDIAN) != 8: + raise AssertionError('Format "Q" is not a 64-bit number.') + + +def PackTag(field_number, wire_type): + """Returns an unsigned 32-bit integer that encodes the field number and + wire type information in standard protocol message wire format. + + Args: + field_number: Expected to be an integer in the range [1, 1 << 29) + wire_type: One of the WIRETYPE_* constants. + """ + if not 0 <= wire_type <= _WIRETYPE_MAX: + raise message.EncodeError('Unknown wire type: %d' % wire_type) + return (field_number << TAG_TYPE_BITS) | wire_type + + +def UnpackTag(tag): + """The inverse of PackTag(). Given an unsigned 32-bit number, + returns a (field_number, wire_type) tuple. + """ + return (tag >> TAG_TYPE_BITS), (tag & TAG_TYPE_MASK) + + +def ZigZagEncode(value): + """ZigZag Transform: Encodes signed integers so that they can be + effectively used with varint encoding. See wire_format.h for + more details. + """ + if value >= 0: + return value << 1 + return (value << 1) ^ (~0) + + +def ZigZagDecode(value): + """Inverse of ZigZagEncode().""" + if not value & 0x1: + return value >> 1 + return (value >> 1) ^ (~0) + + + +# The *ByteSize() functions below return the number of bytes required to +# serialize "field number + type" information and then serialize the value. + + +def Int32ByteSize(field_number, int32): + return Int64ByteSize(field_number, int32) + + +def Int32ByteSizeNoTag(int32): + return _VarUInt64ByteSizeNoTag(0xffffffffffffffff & int32) + + +def Int64ByteSize(field_number, int64): + # Have to convert to uint before calling UInt64ByteSize(). + return UInt64ByteSize(field_number, 0xffffffffffffffff & int64) + + +def UInt32ByteSize(field_number, uint32): + return UInt64ByteSize(field_number, uint32) + + +def UInt64ByteSize(field_number, uint64): + return TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(uint64) + + +def SInt32ByteSize(field_number, int32): + return UInt32ByteSize(field_number, ZigZagEncode(int32)) + + +def SInt64ByteSize(field_number, int64): + return UInt64ByteSize(field_number, ZigZagEncode(int64)) + + +def Fixed32ByteSize(field_number, fixed32): + return TagByteSize(field_number) + 4 + + +def Fixed64ByteSize(field_number, fixed64): + return TagByteSize(field_number) + 8 + + +def SFixed32ByteSize(field_number, sfixed32): + return TagByteSize(field_number) + 4 + + +def SFixed64ByteSize(field_number, sfixed64): + return TagByteSize(field_number) + 8 + + +def FloatByteSize(field_number, flt): + return TagByteSize(field_number) + 4 + + +def DoubleByteSize(field_number, double): + return TagByteSize(field_number) + 8 + + +def BoolByteSize(field_number, b): + return TagByteSize(field_number) + 1 + + +def EnumByteSize(field_number, enum): + return UInt32ByteSize(field_number, enum) + + +def StringByteSize(field_number, string): + return BytesByteSize(field_number, string.encode('utf-8')) + + +def BytesByteSize(field_number, b): + return (TagByteSize(field_number) + + _VarUInt64ByteSizeNoTag(len(b)) + + len(b)) + + +def GroupByteSize(field_number, message): + return (2 * TagByteSize(field_number) # START and END group. + + message.ByteSize()) + + +def MessageByteSize(field_number, message): + return (TagByteSize(field_number) + + _VarUInt64ByteSizeNoTag(message.ByteSize()) + + message.ByteSize()) + + +def MessageSetItemByteSize(field_number, msg): + # First compute the sizes of the tags. + # There are 2 tags for the beginning and ending of the repeated group, that + # is field number 1, one with field number 2 (type_id) and one with field + # number 3 (message). + total_size = (2 * TagByteSize(1) + TagByteSize(2) + TagByteSize(3)) + + # Add the number of bytes for type_id. + total_size += _VarUInt64ByteSizeNoTag(field_number) + + message_size = msg.ByteSize() + + # The number of bytes for encoding the length of the message. + total_size += _VarUInt64ByteSizeNoTag(message_size) + + # The size of the message. + total_size += message_size + return total_size + + +def TagByteSize(field_number): + """Returns the bytes required to serialize a tag with this field number.""" + # Just pass in type 0, since the type won't affect the tag+type size. + return _VarUInt64ByteSizeNoTag(PackTag(field_number, 0)) + + +# Private helper function for the *ByteSize() functions above. + +def _VarUInt64ByteSizeNoTag(uint64): + """Returns the number of bytes required to serialize a single varint + using boundary value comparisons. (unrolled loop optimization -WPierce) + uint64 must be unsigned. + """ + if uint64 <= 0x7f: return 1 + if uint64 <= 0x3fff: return 2 + if uint64 <= 0x1fffff: return 3 + if uint64 <= 0xfffffff: return 4 + if uint64 <= 0x7ffffffff: return 5 + if uint64 <= 0x3ffffffffff: return 6 + if uint64 <= 0x1ffffffffffff: return 7 + if uint64 <= 0xffffffffffffff: return 8 + if uint64 <= 0x7fffffffffffffff: return 9 + if uint64 > UINT64_MAX: + raise message.EncodeError('Value out of range: %d' % uint64) + return 10 + + +NON_PACKABLE_TYPES = ( + descriptor.FieldDescriptor.TYPE_STRING, + descriptor.FieldDescriptor.TYPE_GROUP, + descriptor.FieldDescriptor.TYPE_MESSAGE, + descriptor.FieldDescriptor.TYPE_BYTES +) + + +def IsTypePackable(field_type): + """Return true iff packable = true is valid for fields of this type. + + Args: + field_type: a FieldDescriptor::Type value. + + Returns: + True iff fields of this type are packable. + """ + return field_type not in NON_PACKABLE_TYPES diff --git a/contrib/python/protobuf/py3/google/protobuf/json_format.py b/contrib/python/protobuf/py3/google/protobuf/json_format.py new file mode 100644 index 0000000000..965614d803 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/json_format.py @@ -0,0 +1,865 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains routines for printing protocol messages in JSON format. + +Simple usage example: + + # Create a proto object and serialize it to a json format string. + message = my_proto_pb2.MyMessage(foo='bar') + json_string = json_format.MessageToJson(message) + + # Parse a json format string to proto object. + message = json_format.Parse(json_string, my_proto_pb2.MyMessage()) +""" + +__author__ = 'jieluo@google.com (Jie Luo)' + +# pylint: disable=g-statement-before-imports,g-import-not-at-top +try: + from collections import OrderedDict +except ImportError: + from ordereddict import OrderedDict # PY26 +# pylint: enable=g-statement-before-imports,g-import-not-at-top + +import base64 +import json +import math + +from operator import methodcaller + +import re +import sys + +import six + +from google.protobuf.internal import type_checkers +from google.protobuf import descriptor +from google.protobuf import symbol_database + + +_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S' +_INT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_INT32, + descriptor.FieldDescriptor.CPPTYPE_UINT32, + descriptor.FieldDescriptor.CPPTYPE_INT64, + descriptor.FieldDescriptor.CPPTYPE_UINT64]) +_INT64_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_INT64, + descriptor.FieldDescriptor.CPPTYPE_UINT64]) +_FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT, + descriptor.FieldDescriptor.CPPTYPE_DOUBLE]) +_INFINITY = 'Infinity' +_NEG_INFINITY = '-Infinity' +_NAN = 'NaN' + +_UNPAIRED_SURROGATE_PATTERN = re.compile(six.u( + r'[\ud800-\udbff](?![\udc00-\udfff])|(?<![\ud800-\udbff])[\udc00-\udfff]' +)) + +_VALID_EXTENSION_NAME = re.compile(r'\[[a-zA-Z0-9\._]*\]$') + + +class Error(Exception): + """Top-level module error for json_format.""" + + +class SerializeToJsonError(Error): + """Thrown if serialization to JSON fails.""" + + +class ParseError(Error): + """Thrown in case of parsing error.""" + + +def MessageToJson( + message, + including_default_value_fields=False, + preserving_proto_field_name=False, + indent=2, + sort_keys=False, + use_integers_for_enums=False, + descriptor_pool=None, + float_precision=None): + """Converts protobuf message to JSON format. + + Args: + message: The protocol buffers message instance to serialize. + including_default_value_fields: If True, singular primitive fields, + repeated fields, and map fields will always be serialized. If + False, only serialize non-empty fields. Singular message fields + and oneof fields are not affected by this option. + preserving_proto_field_name: If True, use the original proto field + names as defined in the .proto file. If False, convert the field + names to lowerCamelCase. + indent: The JSON object will be pretty-printed with this indent level. + An indent level of 0 or negative will only insert newlines. + sort_keys: If True, then the output will be sorted by field names. + use_integers_for_enums: If true, print integers instead of enum names. + descriptor_pool: A Descriptor Pool for resolving types. If None use the + default. + float_precision: If set, use this to specify float field valid digits. + + Returns: + A string containing the JSON formatted protocol buffer message. + """ + printer = _Printer( + including_default_value_fields, + preserving_proto_field_name, + use_integers_for_enums, + descriptor_pool, + float_precision=float_precision) + return printer.ToJsonString(message, indent, sort_keys) + + +def MessageToDict( + message, + including_default_value_fields=False, + preserving_proto_field_name=False, + use_integers_for_enums=False, + descriptor_pool=None, + float_precision=None): + """Converts protobuf message to a dictionary. + + When the dictionary is encoded to JSON, it conforms to proto3 JSON spec. + + Args: + message: The protocol buffers message instance to serialize. + including_default_value_fields: If True, singular primitive fields, + repeated fields, and map fields will always be serialized. If + False, only serialize non-empty fields. Singular message fields + and oneof fields are not affected by this option. + preserving_proto_field_name: If True, use the original proto field + names as defined in the .proto file. If False, convert the field + names to lowerCamelCase. + use_integers_for_enums: If true, print integers instead of enum names. + descriptor_pool: A Descriptor Pool for resolving types. If None use the + default. + float_precision: If set, use this to specify float field valid digits. + + Returns: + A dict representation of the protocol buffer message. + """ + printer = _Printer( + including_default_value_fields, + preserving_proto_field_name, + use_integers_for_enums, + descriptor_pool, + float_precision=float_precision) + # pylint: disable=protected-access + return printer._MessageToJsonObject(message) + + +def _IsMapEntry(field): + return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.message_type.has_options and + field.message_type.GetOptions().map_entry) + + +class _Printer(object): + """JSON format printer for protocol message.""" + + def __init__( + self, + including_default_value_fields=False, + preserving_proto_field_name=False, + use_integers_for_enums=False, + descriptor_pool=None, + float_precision=None): + self.including_default_value_fields = including_default_value_fields + self.preserving_proto_field_name = preserving_proto_field_name + self.use_integers_for_enums = use_integers_for_enums + self.descriptor_pool = descriptor_pool + if float_precision: + self.float_format = '.{}g'.format(float_precision) + else: + self.float_format = None + + def ToJsonString(self, message, indent, sort_keys): + js = self._MessageToJsonObject(message) + return json.dumps(js, indent=indent, sort_keys=sort_keys) + + def _MessageToJsonObject(self, message): + """Converts message to an object according to Proto3 JSON Specification.""" + message_descriptor = message.DESCRIPTOR + full_name = message_descriptor.full_name + if _IsWrapperMessage(message_descriptor): + return self._WrapperMessageToJsonObject(message) + if full_name in _WKTJSONMETHODS: + return methodcaller(_WKTJSONMETHODS[full_name][0], message)(self) + js = {} + return self._RegularMessageToJsonObject(message, js) + + def _RegularMessageToJsonObject(self, message, js): + """Converts normal message according to Proto3 JSON Specification.""" + fields = message.ListFields() + + try: + for field, value in fields: + if self.preserving_proto_field_name: + name = field.name + else: + name = field.json_name + if _IsMapEntry(field): + # Convert a map field. + v_field = field.message_type.fields_by_name['value'] + js_map = {} + for key in value: + if isinstance(key, bool): + if key: + recorded_key = 'true' + else: + recorded_key = 'false' + else: + recorded_key = key + js_map[recorded_key] = self._FieldToJsonObject( + v_field, value[key]) + js[name] = js_map + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + # Convert a repeated field. + js[name] = [self._FieldToJsonObject(field, k) + for k in value] + elif field.is_extension: + name = '[%s]' % field.full_name + js[name] = self._FieldToJsonObject(field, value) + else: + js[name] = self._FieldToJsonObject(field, value) + + # Serialize default value if including_default_value_fields is True. + if self.including_default_value_fields: + message_descriptor = message.DESCRIPTOR + for field in message_descriptor.fields: + # Singular message fields and oneof fields will not be affected. + if ((field.label != descriptor.FieldDescriptor.LABEL_REPEATED and + field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE) or + field.containing_oneof): + continue + if self.preserving_proto_field_name: + name = field.name + else: + name = field.json_name + if name in js: + # Skip the field which has been serialized already. + continue + if _IsMapEntry(field): + js[name] = {} + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + js[name] = [] + else: + js[name] = self._FieldToJsonObject(field, field.default_value) + + except ValueError as e: + raise SerializeToJsonError( + 'Failed to serialize {0} field: {1}.'.format(field.name, e)) + + return js + + def _FieldToJsonObject(self, field, value): + """Converts field value according to Proto3 JSON Specification.""" + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + return self._MessageToJsonObject(value) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: + if self.use_integers_for_enums: + return value + if field.enum_type.full_name == 'google.protobuf.NullValue': + return None + enum_value = field.enum_type.values_by_number.get(value, None) + if enum_value is not None: + return enum_value.name + else: + if field.file.syntax == 'proto3': + return value + raise SerializeToJsonError('Enum field contains an integer value ' + 'which can not mapped to an enum value.') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: + if field.type == descriptor.FieldDescriptor.TYPE_BYTES: + # Use base64 Data encoding for bytes + return base64.b64encode(value).decode('utf-8') + else: + return value + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: + return bool(value) + elif field.cpp_type in _INT64_TYPES: + return str(value) + elif field.cpp_type in _FLOAT_TYPES: + if math.isinf(value): + if value < 0.0: + return _NEG_INFINITY + else: + return _INFINITY + if math.isnan(value): + return _NAN + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_FLOAT: + if self.float_format: + return float(format(value, self.float_format)) + else: + return type_checkers.ToShortestFloat(value) + + return value + + def _AnyMessageToJsonObject(self, message): + """Converts Any message according to Proto3 JSON Specification.""" + if not message.ListFields(): + return {} + # Must print @type first, use OrderedDict instead of {} + js = OrderedDict() + type_url = message.type_url + js['@type'] = type_url + sub_message = _CreateMessageFromTypeUrl(type_url, self.descriptor_pool) + sub_message.ParseFromString(message.value) + message_descriptor = sub_message.DESCRIPTOR + full_name = message_descriptor.full_name + if _IsWrapperMessage(message_descriptor): + js['value'] = self._WrapperMessageToJsonObject(sub_message) + return js + if full_name in _WKTJSONMETHODS: + js['value'] = methodcaller(_WKTJSONMETHODS[full_name][0], + sub_message)(self) + return js + return self._RegularMessageToJsonObject(sub_message, js) + + def _GenericMessageToJsonObject(self, message): + """Converts message according to Proto3 JSON Specification.""" + # Duration, Timestamp and FieldMask have ToJsonString method to do the + # convert. Users can also call the method directly. + return message.ToJsonString() + + def _ValueMessageToJsonObject(self, message): + """Converts Value message according to Proto3 JSON Specification.""" + which = message.WhichOneof('kind') + # If the Value message is not set treat as null_value when serialize + # to JSON. The parse back result will be different from original message. + if which is None or which == 'null_value': + return None + if which == 'list_value': + return self._ListValueMessageToJsonObject(message.list_value) + if which == 'struct_value': + value = message.struct_value + else: + value = getattr(message, which) + oneof_descriptor = message.DESCRIPTOR.fields_by_name[which] + return self._FieldToJsonObject(oneof_descriptor, value) + + def _ListValueMessageToJsonObject(self, message): + """Converts ListValue message according to Proto3 JSON Specification.""" + return [self._ValueMessageToJsonObject(value) + for value in message.values] + + def _StructMessageToJsonObject(self, message): + """Converts Struct message according to Proto3 JSON Specification.""" + fields = message.fields + ret = {} + for key in fields: + ret[key] = self._ValueMessageToJsonObject(fields[key]) + return ret + + def _WrapperMessageToJsonObject(self, message): + return self._FieldToJsonObject( + message.DESCRIPTOR.fields_by_name['value'], message.value) + + +def _IsWrapperMessage(message_descriptor): + return message_descriptor.file.name == 'google/protobuf/wrappers.proto' + + +def _DuplicateChecker(js): + result = {} + for name, value in js: + if name in result: + raise ParseError('Failed to load JSON: duplicate key {0}.'.format(name)) + result[name] = value + return result + + +def _CreateMessageFromTypeUrl(type_url, descriptor_pool): + """Creates a message from a type URL.""" + db = symbol_database.Default() + pool = db.pool if descriptor_pool is None else descriptor_pool + type_name = type_url.split('/')[-1] + try: + message_descriptor = pool.FindMessageTypeByName(type_name) + except KeyError: + raise TypeError( + 'Can not find message descriptor by type_url: {0}.'.format(type_url)) + message_class = db.GetPrototype(message_descriptor) + return message_class() + + +def Parse(text, message, ignore_unknown_fields=False, descriptor_pool=None): + """Parses a JSON representation of a protocol message into a message. + + Args: + text: Message JSON representation. + message: A protocol buffer message to merge into. + ignore_unknown_fields: If True, do not raise errors for unknown fields. + descriptor_pool: A Descriptor Pool for resolving types. If None use the + default. + + Returns: + The same message passed as argument. + + Raises:: + ParseError: On JSON parsing problems. + """ + if not isinstance(text, six.text_type): text = text.decode('utf-8') + try: + js = json.loads(text, object_pairs_hook=_DuplicateChecker) + except ValueError as e: + raise ParseError('Failed to load JSON: {0}.'.format(str(e))) + return ParseDict(js, message, ignore_unknown_fields, descriptor_pool) + + +def ParseDict(js_dict, + message, + ignore_unknown_fields=False, + descriptor_pool=None): + """Parses a JSON dictionary representation into a message. + + Args: + js_dict: Dict representation of a JSON message. + message: A protocol buffer message to merge into. + ignore_unknown_fields: If True, do not raise errors for unknown fields. + descriptor_pool: A Descriptor Pool for resolving types. If None use the + default. + + Returns: + The same message passed as argument. + """ + parser = _Parser(ignore_unknown_fields, descriptor_pool) + parser.ConvertMessage(js_dict, message) + return message + + +_INT_OR_FLOAT = six.integer_types + (float,) + + +class _Parser(object): + """JSON format parser for protocol message.""" + + def __init__(self, ignore_unknown_fields, descriptor_pool): + self.ignore_unknown_fields = ignore_unknown_fields + self.descriptor_pool = descriptor_pool + + def ConvertMessage(self, value, message): + """Convert a JSON object into a message. + + Args: + value: A JSON object. + message: A WKT or regular protocol message to record the data. + + Raises: + ParseError: In case of convert problems. + """ + message_descriptor = message.DESCRIPTOR + full_name = message_descriptor.full_name + if _IsWrapperMessage(message_descriptor): + self._ConvertWrapperMessage(value, message) + elif full_name in _WKTJSONMETHODS: + methodcaller(_WKTJSONMETHODS[full_name][1], value, message)(self) + else: + self._ConvertFieldValuePair(value, message) + + def _ConvertFieldValuePair(self, js, message): + """Convert field value pairs into regular message. + + Args: + js: A JSON object to convert the field value pairs. + message: A regular protocol message to record the data. + + Raises: + ParseError: In case of problems converting. + """ + names = [] + message_descriptor = message.DESCRIPTOR + fields_by_json_name = dict((f.json_name, f) + for f in message_descriptor.fields) + for name in js: + try: + field = fields_by_json_name.get(name, None) + if not field: + field = message_descriptor.fields_by_name.get(name, None) + if not field and _VALID_EXTENSION_NAME.match(name): + if not message_descriptor.is_extendable: + raise ParseError('Message type {0} does not have extensions'.format( + message_descriptor.full_name)) + identifier = name[1:-1] # strip [] brackets + # pylint: disable=protected-access + field = message.Extensions._FindExtensionByName(identifier) + # pylint: enable=protected-access + if not field: + # Try looking for extension by the message type name, dropping the + # field name following the final . separator in full_name. + identifier = '.'.join(identifier.split('.')[:-1]) + # pylint: disable=protected-access + field = message.Extensions._FindExtensionByName(identifier) + # pylint: enable=protected-access + if not field: + if self.ignore_unknown_fields: + continue + raise ParseError( + ('Message type "{0}" has no field named "{1}".\n' + ' Available Fields(except extensions): {2}').format( + message_descriptor.full_name, name, + [f.json_name for f in message_descriptor.fields])) + if name in names: + raise ParseError('Message type "{0}" should not have multiple ' + '"{1}" fields.'.format( + message.DESCRIPTOR.full_name, name)) + names.append(name) + value = js[name] + # Check no other oneof field is parsed. + if field.containing_oneof is not None and value is not None: + oneof_name = field.containing_oneof.name + if oneof_name in names: + raise ParseError('Message type "{0}" should not have multiple ' + '"{1}" oneof fields.'.format( + message.DESCRIPTOR.full_name, oneof_name)) + names.append(oneof_name) + + if value is None: + if (field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE + and field.message_type.full_name == 'google.protobuf.Value'): + sub_message = getattr(message, field.name) + sub_message.null_value = 0 + elif (field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM + and field.enum_type.full_name == 'google.protobuf.NullValue'): + setattr(message, field.name, 0) + else: + message.ClearField(field.name) + continue + + # Parse field value. + if _IsMapEntry(field): + message.ClearField(field.name) + self._ConvertMapFieldValue(value, message, field) + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + message.ClearField(field.name) + if not isinstance(value, list): + raise ParseError('repeated field {0} must be in [] which is ' + '{1}.'.format(name, value)) + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + # Repeated message field. + for item in value: + sub_message = getattr(message, field.name).add() + # None is a null_value in Value. + if (item is None and + sub_message.DESCRIPTOR.full_name != 'google.protobuf.Value'): + raise ParseError('null is not allowed to be used as an element' + ' in a repeated field.') + self.ConvertMessage(item, sub_message) + else: + # Repeated scalar field. + for item in value: + if item is None: + raise ParseError('null is not allowed to be used as an element' + ' in a repeated field.') + getattr(message, field.name).append( + _ConvertScalarFieldValue(item, field)) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + if field.is_extension: + sub_message = message.Extensions[field] + else: + sub_message = getattr(message, field.name) + sub_message.SetInParent() + self.ConvertMessage(value, sub_message) + else: + if field.is_extension: + message.Extensions[field] = _ConvertScalarFieldValue(value, field) + else: + setattr(message, field.name, _ConvertScalarFieldValue(value, field)) + except ParseError as e: + if field and field.containing_oneof is None: + raise ParseError('Failed to parse {0} field: {1}.'.format(name, e)) + else: + raise ParseError(str(e)) + except ValueError as e: + raise ParseError('Failed to parse {0} field: {1}.'.format(name, e)) + except TypeError as e: + raise ParseError('Failed to parse {0} field: {1}.'.format(name, e)) + + def _ConvertAnyMessage(self, value, message): + """Convert a JSON representation into Any message.""" + if isinstance(value, dict) and not value: + return + try: + type_url = value['@type'] + except KeyError: + raise ParseError('@type is missing when parsing any message.') + + sub_message = _CreateMessageFromTypeUrl(type_url, self.descriptor_pool) + message_descriptor = sub_message.DESCRIPTOR + full_name = message_descriptor.full_name + if _IsWrapperMessage(message_descriptor): + self._ConvertWrapperMessage(value['value'], sub_message) + elif full_name in _WKTJSONMETHODS: + methodcaller( + _WKTJSONMETHODS[full_name][1], value['value'], sub_message)(self) + else: + del value['@type'] + self._ConvertFieldValuePair(value, sub_message) + value['@type'] = type_url + # Sets Any message + message.value = sub_message.SerializeToString() + message.type_url = type_url + + def _ConvertGenericMessage(self, value, message): + """Convert a JSON representation into message with FromJsonString.""" + # Duration, Timestamp, FieldMask have a FromJsonString method to do the + # conversion. Users can also call the method directly. + try: + message.FromJsonString(value) + except ValueError as e: + raise ParseError(e) + + def _ConvertValueMessage(self, value, message): + """Convert a JSON representation into Value message.""" + if isinstance(value, dict): + self._ConvertStructMessage(value, message.struct_value) + elif isinstance(value, list): + self. _ConvertListValueMessage(value, message.list_value) + elif value is None: + message.null_value = 0 + elif isinstance(value, bool): + message.bool_value = value + elif isinstance(value, six.string_types): + message.string_value = value + elif isinstance(value, _INT_OR_FLOAT): + message.number_value = value + else: + raise ParseError('Value {0} has unexpected type {1}.'.format( + value, type(value))) + + def _ConvertListValueMessage(self, value, message): + """Convert a JSON representation into ListValue message.""" + if not isinstance(value, list): + raise ParseError( + 'ListValue must be in [] which is {0}.'.format(value)) + message.ClearField('values') + for item in value: + self._ConvertValueMessage(item, message.values.add()) + + def _ConvertStructMessage(self, value, message): + """Convert a JSON representation into Struct message.""" + if not isinstance(value, dict): + raise ParseError( + 'Struct must be in a dict which is {0}.'.format(value)) + # Clear will mark the struct as modified so it will be created even if + # there are no values. + message.Clear() + for key in value: + self._ConvertValueMessage(value[key], message.fields[key]) + return + + def _ConvertWrapperMessage(self, value, message): + """Convert a JSON representation into Wrapper message.""" + field = message.DESCRIPTOR.fields_by_name['value'] + setattr(message, 'value', _ConvertScalarFieldValue(value, field)) + + def _ConvertMapFieldValue(self, value, message, field): + """Convert map field value for a message map field. + + Args: + value: A JSON object to convert the map field value. + message: A protocol message to record the converted data. + field: The descriptor of the map field to be converted. + + Raises: + ParseError: In case of convert problems. + """ + if not isinstance(value, dict): + raise ParseError( + 'Map field {0} must be in a dict which is {1}.'.format( + field.name, value)) + key_field = field.message_type.fields_by_name['key'] + value_field = field.message_type.fields_by_name['value'] + for key in value: + key_value = _ConvertScalarFieldValue(key, key_field, True) + if value_field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + self.ConvertMessage(value[key], getattr( + message, field.name)[key_value]) + else: + getattr(message, field.name)[key_value] = _ConvertScalarFieldValue( + value[key], value_field) + + +def _ConvertScalarFieldValue(value, field, require_str=False): + """Convert a single scalar field value. + + Args: + value: A scalar value to convert the scalar field value. + field: The descriptor of the field to convert. + require_str: If True, the field value must be a str. + + Returns: + The converted scalar field value + + Raises: + ParseError: In case of convert problems. + """ + if field.cpp_type in _INT_TYPES: + return _ConvertInteger(value) + elif field.cpp_type in _FLOAT_TYPES: + return _ConvertFloat(value, field) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: + return _ConvertBool(value, require_str) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: + if field.type == descriptor.FieldDescriptor.TYPE_BYTES: + if isinstance(value, six.text_type): + encoded = value.encode('utf-8') + else: + encoded = value + # Add extra padding '=' + padded_value = encoded + b'=' * (4 - len(encoded) % 4) + return base64.urlsafe_b64decode(padded_value) + else: + # Checking for unpaired surrogates appears to be unreliable, + # depending on the specific Python version, so we check manually. + if _UNPAIRED_SURROGATE_PATTERN.search(value): + raise ParseError('Unpaired surrogate') + return value + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: + # Convert an enum value. + enum_value = field.enum_type.values_by_name.get(value, None) + if enum_value is None: + try: + number = int(value) + enum_value = field.enum_type.values_by_number.get(number, None) + except ValueError: + raise ParseError('Invalid enum value {0} for enum type {1}.'.format( + value, field.enum_type.full_name)) + if enum_value is None: + if field.file.syntax == 'proto3': + # Proto3 accepts unknown enums. + return number + raise ParseError('Invalid enum value {0} for enum type {1}.'.format( + value, field.enum_type.full_name)) + return enum_value.number + + +def _ConvertInteger(value): + """Convert an integer. + + Args: + value: A scalar value to convert. + + Returns: + The integer value. + + Raises: + ParseError: If an integer couldn't be consumed. + """ + if isinstance(value, float) and not value.is_integer(): + raise ParseError('Couldn\'t parse integer: {0}.'.format(value)) + + if isinstance(value, six.text_type) and value.find(' ') != -1: + raise ParseError('Couldn\'t parse integer: "{0}".'.format(value)) + + if isinstance(value, bool): + raise ParseError('Bool value {0} is not acceptable for ' + 'integer field.'.format(value)) + + return int(value) + + +def _ConvertFloat(value, field): + """Convert an floating point number.""" + if isinstance(value, float): + if math.isnan(value): + raise ParseError('Couldn\'t parse NaN, use quoted "NaN" instead.') + if math.isinf(value): + if value > 0: + raise ParseError('Couldn\'t parse Infinity or value too large, ' + 'use quoted "Infinity" instead.') + else: + raise ParseError('Couldn\'t parse -Infinity or value too small, ' + 'use quoted "-Infinity" instead.') + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_FLOAT: + # pylint: disable=protected-access + if value > type_checkers._FLOAT_MAX: + raise ParseError('Float value too large') + # pylint: disable=protected-access + if value < type_checkers._FLOAT_MIN: + raise ParseError('Float value too small') + if value == 'nan': + raise ParseError('Couldn\'t parse float "nan", use "NaN" instead.') + try: + # Assume Python compatible syntax. + return float(value) + except ValueError: + # Check alternative spellings. + if value == _NEG_INFINITY: + return float('-inf') + elif value == _INFINITY: + return float('inf') + elif value == _NAN: + return float('nan') + else: + raise ParseError('Couldn\'t parse float: {0}.'.format(value)) + + +def _ConvertBool(value, require_str): + """Convert a boolean value. + + Args: + value: A scalar value to convert. + require_str: If True, value must be a str. + + Returns: + The bool parsed. + + Raises: + ParseError: If a boolean value couldn't be consumed. + """ + if require_str: + if value == 'true': + return True + elif value == 'false': + return False + else: + raise ParseError('Expected "true" or "false", not {0}.'.format(value)) + + if not isinstance(value, bool): + raise ParseError('Expected true or false without quotes.') + return value + +_WKTJSONMETHODS = { + 'google.protobuf.Any': ['_AnyMessageToJsonObject', + '_ConvertAnyMessage'], + 'google.protobuf.Duration': ['_GenericMessageToJsonObject', + '_ConvertGenericMessage'], + 'google.protobuf.FieldMask': ['_GenericMessageToJsonObject', + '_ConvertGenericMessage'], + 'google.protobuf.ListValue': ['_ListValueMessageToJsonObject', + '_ConvertListValueMessage'], + 'google.protobuf.Struct': ['_StructMessageToJsonObject', + '_ConvertStructMessage'], + 'google.protobuf.Timestamp': ['_GenericMessageToJsonObject', + '_ConvertGenericMessage'], + 'google.protobuf.Value': ['_ValueMessageToJsonObject', + '_ConvertValueMessage'] +} diff --git a/contrib/python/protobuf/py3/google/protobuf/message.py b/contrib/python/protobuf/py3/google/protobuf/message.py new file mode 100644 index 0000000000..224d2fc491 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/message.py @@ -0,0 +1,413 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# TODO(robinson): We should just make these methods all "pure-virtual" and move +# all implementation out, into reflection.py for now. + + +"""Contains an abstract base class for protocol messages.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +class Error(Exception): + """Base error type for this module.""" + pass + + +class DecodeError(Error): + """Exception raised when deserializing messages.""" + pass + + +class EncodeError(Error): + """Exception raised when serializing messages.""" + pass + + +class Message(object): + + """Abstract base class for protocol messages. + + Protocol message classes are almost always generated by the protocol + compiler. These generated types subclass Message and implement the methods + shown below. + """ + + # TODO(robinson): Link to an HTML document here. + + # TODO(robinson): Document that instances of this class will also + # have an Extensions attribute with __getitem__ and __setitem__. + # Again, not sure how to best convey this. + + # TODO(robinson): Document that the class must also have a static + # RegisterExtension(extension_field) method. + # Not sure how to best express at this point. + + # TODO(robinson): Document these fields and methods. + + __slots__ = [] + + #: The :class:`google.protobuf.descriptor.Descriptor` for this message type. + DESCRIPTOR = None + + def __deepcopy__(self, memo=None): + clone = type(self)() + clone.MergeFrom(self) + return clone + + def __eq__(self, other_msg): + """Recursively compares two messages by value and structure.""" + raise NotImplementedError + + def __ne__(self, other_msg): + # Can't just say self != other_msg, since that would infinitely recurse. :) + return not self == other_msg + + def __hash__(self): + raise TypeError('unhashable object') + + def __str__(self): + """Outputs a human-readable representation of the message.""" + raise NotImplementedError + + def __unicode__(self): + """Outputs a human-readable representation of the message.""" + raise NotImplementedError + + def MergeFrom(self, other_msg): + """Merges the contents of the specified message into current message. + + This method merges the contents of the specified message into the current + message. Singular fields that are set in the specified message overwrite + the corresponding fields in the current message. Repeated fields are + appended. Singular sub-messages and groups are recursively merged. + + Args: + other_msg (Message): A message to merge into the current message. + """ + raise NotImplementedError + + def CopyFrom(self, other_msg): + """Copies the content of the specified message into the current message. + + The method clears the current message and then merges the specified + message using MergeFrom. + + Args: + other_msg (Message): A message to copy into the current one. + """ + if self is other_msg: + return + self.Clear() + self.MergeFrom(other_msg) + + def Clear(self): + """Clears all data that was set in the message.""" + raise NotImplementedError + + def SetInParent(self): + """Mark this as present in the parent. + + This normally happens automatically when you assign a field of a + sub-message, but sometimes you want to make the sub-message + present while keeping it empty. If you find yourself using this, + you may want to reconsider your design. + """ + raise NotImplementedError + + def IsInitialized(self): + """Checks if the message is initialized. + + Returns: + bool: The method returns True if the message is initialized (i.e. all of + its required fields are set). + """ + raise NotImplementedError + + # TODO(robinson): MergeFromString() should probably return None and be + # implemented in terms of a helper that returns the # of bytes read. Our + # deserialization routines would use the helper when recursively + # deserializing, but the end user would almost always just want the no-return + # MergeFromString(). + + def MergeFromString(self, serialized): + """Merges serialized protocol buffer data into this message. + + When we find a field in `serialized` that is already present + in this message: + + - If it's a "repeated" field, we append to the end of our list. + - Else, if it's a scalar, we overwrite our field. + - Else, (it's a nonrepeated composite), we recursively merge + into the existing composite. + + Args: + serialized (bytes): Any object that allows us to call + ``memoryview(serialized)`` to access a string of bytes using the + buffer interface. + + Returns: + int: The number of bytes read from `serialized`. + For non-group messages, this will always be `len(serialized)`, + but for messages which are actually groups, this will + generally be less than `len(serialized)`, since we must + stop when we reach an ``END_GROUP`` tag. Note that if + we *do* stop because of an ``END_GROUP`` tag, the number + of bytes returned does not include the bytes + for the ``END_GROUP`` tag information. + + Raises: + DecodeError: if the input cannot be parsed. + """ + # TODO(robinson): Document handling of unknown fields. + # TODO(robinson): When we switch to a helper, this will return None. + raise NotImplementedError + + def ParseFromString(self, serialized): + """Parse serialized protocol buffer data into this message. + + Like :func:`MergeFromString()`, except we clear the object first. + """ + self.Clear() + return self.MergeFromString(serialized) + + def SerializeToString(self, **kwargs): + """Serializes the protocol message to a binary string. + + Keyword Args: + deterministic (bool): If true, requests deterministic serialization + of the protobuf, with predictable ordering of map keys. + + Returns: + A binary string representation of the message if all of the required + fields in the message are set (i.e. the message is initialized). + + Raises: + EncodeError: if the message isn't initialized (see :func:`IsInitialized`). + """ + raise NotImplementedError + + def SerializePartialToString(self, **kwargs): + """Serializes the protocol message to a binary string. + + This method is similar to SerializeToString but doesn't check if the + message is initialized. + + Keyword Args: + deterministic (bool): If true, requests deterministic serialization + of the protobuf, with predictable ordering of map keys. + + Returns: + bytes: A serialized representation of the partial message. + """ + raise NotImplementedError + + # TODO(robinson): Decide whether we like these better + # than auto-generated has_foo() and clear_foo() methods + # on the instances themselves. This way is less consistent + # with C++, but it makes reflection-type access easier and + # reduces the number of magically autogenerated things. + # + # TODO(robinson): Be sure to document (and test) exactly + # which field names are accepted here. Are we case-sensitive? + # What do we do with fields that share names with Python keywords + # like 'lambda' and 'yield'? + # + # nnorwitz says: + # """ + # Typically (in python), an underscore is appended to names that are + # keywords. So they would become lambda_ or yield_. + # """ + def ListFields(self): + """Returns a list of (FieldDescriptor, value) tuples for present fields. + + A message field is non-empty if HasField() would return true. A singular + primitive field is non-empty if HasField() would return true in proto2 or it + is non zero in proto3. A repeated field is non-empty if it contains at least + one element. The fields are ordered by field number. + + Returns: + list[tuple(FieldDescriptor, value)]: field descriptors and values + for all fields in the message which are not empty. The values vary by + field type. + """ + raise NotImplementedError + + def HasField(self, field_name): + """Checks if a certain field is set for the message. + + For a oneof group, checks if any field inside is set. Note that if the + field_name is not defined in the message descriptor, :exc:`ValueError` will + be raised. + + Args: + field_name (str): The name of the field to check for presence. + + Returns: + bool: Whether a value has been set for the named field. + + Raises: + ValueError: if the `field_name` is not a member of this message. + """ + raise NotImplementedError + + def ClearField(self, field_name): + """Clears the contents of a given field. + + Inside a oneof group, clears the field set. If the name neither refers to a + defined field or oneof group, :exc:`ValueError` is raised. + + Args: + field_name (str): The name of the field to check for presence. + + Raises: + ValueError: if the `field_name` is not a member of this message. + """ + raise NotImplementedError + + def WhichOneof(self, oneof_group): + """Returns the name of the field that is set inside a oneof group. + + If no field is set, returns None. + + Args: + oneof_group (str): the name of the oneof group to check. + + Returns: + str or None: The name of the group that is set, or None. + + Raises: + ValueError: no group with the given name exists + """ + raise NotImplementedError + + def HasExtension(self, extension_handle): + """Checks if a certain extension is present for this message. + + Extensions are retrieved using the :attr:`Extensions` mapping (if present). + + Args: + extension_handle: The handle for the extension to check. + + Returns: + bool: Whether the extension is present for this message. + + Raises: + KeyError: if the extension is repeated. Similar to repeated fields, + there is no separate notion of presence: a "not present" repeated + extension is an empty list. + """ + raise NotImplementedError + + def ClearExtension(self, extension_handle): + """Clears the contents of a given extension. + + Args: + extension_handle: The handle for the extension to clear. + """ + raise NotImplementedError + + def UnknownFields(self): + """Returns the UnknownFieldSet. + + Returns: + UnknownFieldSet: The unknown fields stored in this message. + """ + raise NotImplementedError + + def DiscardUnknownFields(self): + """Clears all fields in the :class:`UnknownFieldSet`. + + This operation is recursive for nested message. + """ + raise NotImplementedError + + def ByteSize(self): + """Returns the serialized size of this message. + + Recursively calls ByteSize() on all contained messages. + + Returns: + int: The number of bytes required to serialize this message. + """ + raise NotImplementedError + + def _SetListener(self, message_listener): + """Internal method used by the protocol message implementation. + Clients should not call this directly. + + Sets a listener that this message will call on certain state transitions. + + The purpose of this method is to register back-edges from children to + parents at runtime, for the purpose of setting "has" bits and + byte-size-dirty bits in the parent and ancestor objects whenever a child or + descendant object is modified. + + If the client wants to disconnect this Message from the object tree, she + explicitly sets callback to None. + + If message_listener is None, unregisters any existing listener. Otherwise, + message_listener must implement the MessageListener interface in + internal/message_listener.py, and we discard any listener registered + via a previous _SetListener() call. + """ + raise NotImplementedError + + def __getstate__(self): + """Support the pickle protocol.""" + return dict(serialized=self.SerializePartialToString()) + + def __setstate__(self, state): + """Support the pickle protocol.""" + self.__init__() + serialized = state['serialized'] + # On Python 3, using encoding='latin1' is required for unpickling + # protos pickled by Python 2. + if not isinstance(serialized, bytes): + serialized = serialized.encode('latin1') + self.ParseFromString(serialized) + + def __reduce__(self): + message_descriptor = self.DESCRIPTOR + if message_descriptor.containing_type is None: + return type(self), (), self.__getstate__() + # the message type must be nested. + # Python does not pickle nested classes; use the symbol_database on the + # receiving end. + container = message_descriptor + return (_InternalConstructMessage, (container.full_name,), + self.__getstate__()) + + +def _InternalConstructMessage(full_name): + """Constructs a nested message.""" + from google.protobuf import symbol_database # pylint:disable=g-import-not-at-top + + return symbol_database.Default().GetSymbol(full_name)() diff --git a/contrib/python/protobuf/py3/google/protobuf/message_factory.py b/contrib/python/protobuf/py3/google/protobuf/message_factory.py new file mode 100644 index 0000000000..7dfaec88e1 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/message_factory.py @@ -0,0 +1,187 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Provides a factory class for generating dynamic messages. + +The easiest way to use this class is if you have access to the FileDescriptor +protos containing the messages you want to create you can just do the following: + +message_classes = message_factory.GetMessages(iterable_of_file_descriptors) +my_proto_instance = message_classes['some.proto.package.MessageName']() +""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +from google.protobuf.internal import api_implementation +from google.protobuf import descriptor_pool +from google.protobuf import message + +if api_implementation.Type() == 'cpp': + from google.protobuf.pyext import cpp_message as message_impl +else: + from google.protobuf.internal import python_message as message_impl + + +# The type of all Message classes. +_GENERATED_PROTOCOL_MESSAGE_TYPE = message_impl.GeneratedProtocolMessageType + + +class MessageFactory(object): + """Factory for creating Proto2 messages from descriptors in a pool.""" + + def __init__(self, pool=None): + """Initializes a new factory.""" + self.pool = pool or descriptor_pool.DescriptorPool() + + # local cache of all classes built from protobuf descriptors + self._classes = {} + + def GetPrototype(self, descriptor): + """Obtains a proto2 message class based on the passed in descriptor. + + Passing a descriptor with a fully qualified name matching a previous + invocation will cause the same class to be returned. + + Args: + descriptor: The descriptor to build from. + + Returns: + A class describing the passed in descriptor. + """ + if descriptor not in self._classes: + result_class = self.CreatePrototype(descriptor) + # The assignment to _classes is redundant for the base implementation, but + # might avoid confusion in cases where CreatePrototype gets overridden and + # does not call the base implementation. + self._classes[descriptor] = result_class + return result_class + return self._classes[descriptor] + + def CreatePrototype(self, descriptor): + """Builds a proto2 message class based on the passed in descriptor. + + Don't call this function directly, it always creates a new class. Call + GetPrototype() instead. This method is meant to be overridden in subblasses + to perform additional operations on the newly constructed class. + + Args: + descriptor: The descriptor to build from. + + Returns: + A class describing the passed in descriptor. + """ + descriptor_name = descriptor.name + if str is bytes: # PY2 + descriptor_name = descriptor.name.encode('ascii', 'ignore') + result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE( + descriptor_name, + (message.Message,), + { + 'DESCRIPTOR': descriptor, + # If module not set, it wrongly points to message_factory module. + '__module__': None, + }) + result_class._FACTORY = self # pylint: disable=protected-access + # Assign in _classes before doing recursive calls to avoid infinite + # recursion. + self._classes[descriptor] = result_class + for field in descriptor.fields: + if field.message_type: + self.GetPrototype(field.message_type) + for extension in result_class.DESCRIPTOR.extensions: + if extension.containing_type not in self._classes: + self.GetPrototype(extension.containing_type) + extended_class = self._classes[extension.containing_type] + extended_class.RegisterExtension(extension) + return result_class + + def GetMessages(self, files): + """Gets all the messages from a specified file. + + This will find and resolve dependencies, failing if the descriptor + pool cannot satisfy them. + + Args: + files: The file names to extract messages from. + + Returns: + A dictionary mapping proto names to the message classes. This will include + any dependent messages as well as any messages defined in the same file as + a specified message. + """ + result = {} + for file_name in files: + file_desc = self.pool.FindFileByName(file_name) + for desc in file_desc.message_types_by_name.values(): + result[desc.full_name] = self.GetPrototype(desc) + + # While the extension FieldDescriptors are created by the descriptor pool, + # the python classes created in the factory need them to be registered + # explicitly, which is done below. + # + # The call to RegisterExtension will specifically check if the + # extension was already registered on the object and either + # ignore the registration if the original was the same, or raise + # an error if they were different. + + for extension in file_desc.extensions_by_name.values(): + if extension.containing_type not in self._classes: + self.GetPrototype(extension.containing_type) + extended_class = self._classes[extension.containing_type] + extended_class.RegisterExtension(extension) + return result + + +_FACTORY = MessageFactory() + + +def GetMessages(file_protos): + """Builds a dictionary of all the messages available in a set of files. + + Args: + file_protos: Iterable of FileDescriptorProto to build messages out of. + + Returns: + A dictionary mapping proto names to the message classes. This will include + any dependent messages as well as any messages defined in the same file as + a specified message. + """ + # The cpp implementation of the protocol buffer library requires to add the + # message in topological order of the dependency graph. + file_by_name = {file_proto.name: file_proto for file_proto in file_protos} + def _AddFile(file_proto): + for dependency in file_proto.dependency: + if dependency in file_by_name: + # Remove from elements to be visited, in order to cut cycles. + _AddFile(file_by_name.pop(dependency)) + _FACTORY.pool.Add(file_proto) + while file_by_name: + _AddFile(file_by_name.popitem()[1]) + return _FACTORY.GetMessages([file_proto.name for file_proto in file_protos]) diff --git a/contrib/python/protobuf/py3/google/protobuf/proto_api.h b/contrib/python/protobuf/py3/google/protobuf/proto_api.h new file mode 100644 index 0000000000..c869bce058 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/proto_api.h @@ -0,0 +1,123 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// This file can be included by other C++ libraries, typically extension modules +// which want to interact with the Python Messages coming from the "cpp" +// implementation of protocol buffers. +// +// Usage: +// Declare a (probably static) variable to hold the API: +// const PyProto_API* py_proto_api; +// In some initialization function, write: +// py_proto_api = static_cast<const PyProto_API*>(PyCapsule_Import( +// PyProtoAPICapsuleName(), 0)); +// if (!py_proto_api) { ...handle ImportError... } +// Then use the methods of the returned class: +// py_proto_api->GetMessagePointer(...); + +#ifndef GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__ +#define GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__ + +#include <Python.h> + +#include <google/protobuf/descriptor_database.h> +#include <google/protobuf/message.h> + +namespace google { +namespace protobuf { +namespace python { + +// Note on the implementation: +// This API is designed after +// https://docs.python.org/3/extending/extending.html#providing-a-c-api-for-an-extension-module +// The class below contains no mutable state, and all methods are "const"; +// we use a C++ class instead of a C struct with functions pointers just because +// the code looks more readable. +struct PyProto_API { + // The API object is created at initialization time and never freed. + // This destructor is never called. + virtual ~PyProto_API() {} + + // Operations on Messages. + + // If the passed object is a Python Message, returns its internal pointer. + // Otherwise, returns NULL with an exception set. + virtual const Message* GetMessagePointer(PyObject* msg) const = 0; + + // If the passed object is a Python Message, returns a mutable pointer. + // Otherwise, returns NULL with an exception set. + // This function will succeed only if there are no other Python objects + // pointing to the message, like submessages or repeated containers. + // With the current implementation, only empty messages are in this case. + virtual Message* GetMutableMessagePointer(PyObject* msg) const = 0; + + // Expose the underlying DescriptorPool and MessageFactory to enable C++ code + // to create Python-compatible message. + virtual const DescriptorPool* GetDefaultDescriptorPool() const = 0; + virtual MessageFactory* GetDefaultMessageFactory() const = 0; + + // Allocate a new protocol buffer as a python object for the provided + // descriptor. This function works even if no Python module has been imported + // for the corresponding protocol buffer class. + // The factory is usually null; when provided, it is the MessageFactory which + // owns the Python class, and will be used to find and create Extensions for + // this message. + // When null is returned, a python error has already been set. + virtual PyObject* NewMessage(const Descriptor* descriptor, + PyObject* py_message_factory) const = 0; + + // Allocate a new protocol buffer where the underlying object is owned by C++. + // The factory must currently be null. This function works even if no Python + // module has been imported for the corresponding protocol buffer class. + // When null is returned, a python error has already been set. + // + // Since this call returns a python object owned by C++, some operations + // are risky, and it must be used carefully. In particular: + // * Avoid modifying the returned object from the C++ side while there are + // existing python references to it or it's subobjects. + // * Avoid using python references to this object or any subobjects after the + // C++ object has been freed. + // * Calling this with the same C++ pointer will result in multiple distinct + // python objects referencing the same C++ object. + virtual PyObject* NewMessageOwnedExternally( + Message* msg, PyObject* py_message_factory) const = 0; +}; + +inline const char* PyProtoAPICapsuleName() { + static const char kCapsuleName[] = + "google.protobuf.pyext._message.proto_API"; + return kCapsuleName; +} + +} // namespace python +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__ diff --git a/contrib/python/protobuf/py3/google/protobuf/proto_builder.py b/contrib/python/protobuf/py3/google/protobuf/proto_builder.py new file mode 100644 index 0000000000..2b7dddcbd3 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/proto_builder.py @@ -0,0 +1,137 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Dynamic Protobuf class creator.""" + +try: + from collections import OrderedDict +except ImportError: + from ordereddict import OrderedDict #PY26 +import hashlib +import os + +from google.protobuf import descriptor_pb2 +from google.protobuf import descriptor +from google.protobuf import message_factory + + +def _GetMessageFromFactory(factory, full_name): + """Get a proto class from the MessageFactory by name. + + Args: + factory: a MessageFactory instance. + full_name: str, the fully qualified name of the proto type. + Returns: + A class, for the type identified by full_name. + Raises: + KeyError, if the proto is not found in the factory's descriptor pool. + """ + proto_descriptor = factory.pool.FindMessageTypeByName(full_name) + proto_cls = factory.GetPrototype(proto_descriptor) + return proto_cls + + +def MakeSimpleProtoClass(fields, full_name=None, pool=None): + """Create a Protobuf class whose fields are basic types. + + Note: this doesn't validate field names! + + Args: + fields: dict of {name: field_type} mappings for each field in the proto. If + this is an OrderedDict the order will be maintained, otherwise the + fields will be sorted by name. + full_name: optional str, the fully-qualified name of the proto type. + pool: optional DescriptorPool instance. + Returns: + a class, the new protobuf class with a FileDescriptor. + """ + factory = message_factory.MessageFactory(pool=pool) + + if full_name is not None: + try: + proto_cls = _GetMessageFromFactory(factory, full_name) + return proto_cls + except KeyError: + # The factory's DescriptorPool doesn't know about this class yet. + pass + + # Get a list of (name, field_type) tuples from the fields dict. If fields was + # an OrderedDict we keep the order, but otherwise we sort the field to ensure + # consistent ordering. + field_items = fields.items() + if not isinstance(fields, OrderedDict): + field_items = sorted(field_items) + + # Use a consistent file name that is unlikely to conflict with any imported + # proto files. + fields_hash = hashlib.sha1() + for f_name, f_type in field_items: + fields_hash.update(f_name.encode('utf-8')) + fields_hash.update(str(f_type).encode('utf-8')) + proto_file_name = fields_hash.hexdigest() + '.proto' + + # If the proto is anonymous, use the same hash to name it. + if full_name is None: + full_name = ('net.proto2.python.public.proto_builder.AnonymousProto_' + + fields_hash.hexdigest()) + try: + proto_cls = _GetMessageFromFactory(factory, full_name) + return proto_cls + except KeyError: + # The factory's DescriptorPool doesn't know about this class yet. + pass + + # This is the first time we see this proto: add a new descriptor to the pool. + factory.pool.Add( + _MakeFileDescriptorProto(proto_file_name, full_name, field_items)) + return _GetMessageFromFactory(factory, full_name) + + +def _MakeFileDescriptorProto(proto_file_name, full_name, field_items): + """Populate FileDescriptorProto for MessageFactory's DescriptorPool.""" + package, name = full_name.rsplit('.', 1) + file_proto = descriptor_pb2.FileDescriptorProto() + file_proto.name = os.path.join(package.replace('.', '/'), proto_file_name) + file_proto.package = package + desc_proto = file_proto.message_type.add() + desc_proto.name = name + for f_number, (f_name, f_type) in enumerate(field_items, 1): + field_proto = desc_proto.field.add() + field_proto.name = f_name + # # If the number falls in the reserved range, reassign it to the correct + # # number after the range. + if f_number >= descriptor.FieldDescriptor.FIRST_RESERVED_FIELD_NUMBER: + f_number += ( + descriptor.FieldDescriptor.LAST_RESERVED_FIELD_NUMBER - + descriptor.FieldDescriptor.FIRST_RESERVED_FIELD_NUMBER + 1) + field_proto.number = f_number + field_proto.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + field_proto.type = f_type + return file_proto diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/README b/contrib/python/protobuf/py3/google/protobuf/pyext/README new file mode 100644 index 0000000000..6d61cb45bf --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/README @@ -0,0 +1,6 @@ +This is the 'v2' C++ implementation for python proto2. + +It is active when: + +PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp +PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2 diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/__init__.py b/contrib/python/protobuf/py3/google/protobuf/pyext/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/__init__.py diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/cpp_message.py b/contrib/python/protobuf/py3/google/protobuf/pyext/cpp_message.py new file mode 100644 index 0000000000..fc8eb32d79 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/cpp_message.py @@ -0,0 +1,65 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Protocol message implementation hooks for C++ implementation. + +Contains helper functions used to create protocol message classes from +Descriptor objects at runtime backed by the protocol buffer C++ API. +""" + +__author__ = 'tibell@google.com (Johan Tibell)' + +from google.protobuf.pyext import _message + + +class GeneratedProtocolMessageType(_message.MessageMeta): + + """Metaclass for protocol message classes created at runtime from Descriptors. + + The protocol compiler currently uses this metaclass to create protocol + message classes at runtime. Clients can also manually create their own + classes at runtime, as in this example: + + mydescriptor = Descriptor(.....) + factory = symbol_database.Default() + factory.pool.AddDescriptor(mydescriptor) + MyProtoClass = factory.GetPrototype(mydescriptor) + myproto_instance = MyProtoClass() + myproto.foo_field = 23 + ... + + The above example will not work for nested types. If you wish to include them, + use reflection.MakeClass() instead of manually instantiating the class in + order to create the appropriate class structure. + """ + + # Must be consistent with the protocol-compiler code in + # proto2/compiler/internal/generator.*. + _DESCRIPTOR_KEY = 'DESCRIPTOR' diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor.cc b/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor.cc new file mode 100644 index 0000000000..de788afa2f --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor.cc @@ -0,0 +1,1973 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: petar@google.com (Petar Petrov) + +#include <google/protobuf/pyext/descriptor.h> + +#include <Python.h> +#include <frameobject.h> + +#include <cstdint> +#include <string> +#include <unordered_map> + +#include <google/protobuf/io/coded_stream.h> +#include <google/protobuf/descriptor.pb.h> +#include <google/protobuf/dynamic_message.h> +#include <google/protobuf/pyext/descriptor_containers.h> +#include <google/protobuf/pyext/descriptor_pool.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/message_factory.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> +#include <google/protobuf/stubs/hash.h> + +#if PY_MAJOR_VERSION >= 3 + #define PyString_FromStringAndSize PyUnicode_FromStringAndSize + #define PyString_Check PyUnicode_Check + #define PyString_InternFromString PyUnicode_InternFromString + #define PyInt_FromLong PyLong_FromLong + #define PyInt_FromSize_t PyLong_FromSize_t + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #endif +#define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \ + PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \ + ? -1 \ + : 0) \ + : PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif + +namespace google { +namespace protobuf { +namespace python { + +// Store interned descriptors, so that the same C++ descriptor yields the same +// Python object. Objects are not immortal: this map does not own the +// references, and items are deleted when the last reference to the object is +// released. +// This is enough to support the "is" operator on live objects. +// All descriptors are stored here. +std::unordered_map<const void*, PyObject*>* interned_descriptors; + +PyObject* PyString_FromCppString(const TProtoStringType& str) { + return PyString_FromStringAndSize(str.c_str(), str.size()); +} + +// Check that the calling Python code is the global scope of a _pb2.py module. +// This function is used to support the current code generated by the proto +// compiler, which creates descriptors, then update some properties. +// For example: +// message_descriptor = Descriptor( +// name='Message', +// fields = [FieldDescriptor(name='field')] +// message_descriptor.fields[0].containing_type = message_descriptor +// +// This code is still executed, but the descriptors now have no other storage +// than the (const) C++ pointer, and are immutable. +// So we let this code pass, by simply ignoring the new value. +// +// From user code, descriptors still look immutable. +// +// TODO(amauryfa): Change the proto2 compiler to remove the assignments, and +// remove this hack. +bool _CalledFromGeneratedFile(int stacklevel) { +#ifndef PYPY_VERSION + // This check is not critical and is somewhat difficult to implement correctly + // in PyPy. + PyFrameObject* frame = PyEval_GetFrame(); + if (frame == NULL) { + return false; + } + while (stacklevel-- > 0) { + frame = frame->f_back; + if (frame == NULL) { + return false; + } + } + + if (frame->f_code->co_filename == NULL) { + return false; + } + char* filename; + Py_ssize_t filename_size; + if (PyString_AsStringAndSize(frame->f_code->co_filename, + &filename, &filename_size) < 0) { + // filename is not a string. + PyErr_Clear(); + return false; + } + if ((filename_size < 3) || + (strcmp(&filename[filename_size - 3], ".py") != 0)) { + // Cython's stack does not have .py file name and is not at global module + // scope. + return true; + } + if (filename_size < 7) { + // filename is too short. + return false; + } + if (strcmp(&filename[filename_size - 7], "_pb2.py") != 0) { + // Filename is not ending with _pb2. + return false; + } + + if (frame->f_globals != frame->f_locals) { + // Not at global module scope + return false; + } +#endif + return true; +} + +// If the calling code is not a _pb2.py file, raise AttributeError. +// To be used in attribute setters. +static int CheckCalledFromGeneratedFile(const char* attr_name) { + if (_CalledFromGeneratedFile(0)) { + return 0; + } + PyErr_Format(PyExc_AttributeError, + "attribute is not writable: %s", attr_name); + return -1; +} + + +#ifndef PyVarObject_HEAD_INIT +#define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size, +#endif +#ifndef Py_TYPE +#define Py_TYPE(ob) (((PyObject*)(ob))->ob_type) +#endif + + +// Helper functions for descriptor objects. + +// A set of templates to retrieve the C++ FileDescriptor of any descriptor. +template<class DescriptorClass> +const FileDescriptor* GetFileDescriptor(const DescriptorClass* descriptor) { + return descriptor->file(); +} +template<> +const FileDescriptor* GetFileDescriptor(const FileDescriptor* descriptor) { + return descriptor; +} +template<> +const FileDescriptor* GetFileDescriptor(const EnumValueDescriptor* descriptor) { + return descriptor->type()->file(); +} +template<> +const FileDescriptor* GetFileDescriptor(const OneofDescriptor* descriptor) { + return descriptor->containing_type()->file(); +} +template<> +const FileDescriptor* GetFileDescriptor(const MethodDescriptor* descriptor) { + return descriptor->service()->file(); +} + +bool Reparse( + PyMessageFactory* message_factory, const Message& from, Message* to) { + // Reparse message. + TProtoStringType serialized; + from.SerializeToString(&serialized); + io::CodedInputStream input( + reinterpret_cast<const uint8_t*>(serialized.c_str()), serialized.size()); + input.SetExtensionRegistry(message_factory->pool->pool, + message_factory->message_factory); + bool success = to->ParseFromCodedStream(&input); + if (!success) { + return false; + } + return true; +} +// Converts options into a Python protobuf, and cache the result. +// +// This is a bit tricky because options can contain extension fields defined in +// the same proto file. In this case the options parsed from the serialized_pb +// have unknown fields, and we need to parse them again. +// +// Always returns a new reference. +template<class DescriptorClass> +static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) { + // Options are cached in the pool that owns the descriptor. + // First search in the cache. + PyDescriptorPool* caching_pool = GetDescriptorPool_FromPool( + GetFileDescriptor(descriptor)->pool()); + std::unordered_map<const void*, PyObject*>* descriptor_options = + caching_pool->descriptor_options; + if (descriptor_options->find(descriptor) != descriptor_options->end()) { + PyObject *value = (*descriptor_options)[descriptor]; + Py_INCREF(value); + return value; + } + + // Similar to the C++ implementation, we return an Options object from the + // default (generated) factory, so that client code know that they can use + // extensions from generated files: + // d.GetOptions().Extensions[some_pb2.extension] + // + // The consequence is that extensions not defined in the default pool won't + // be available. If needed, we could add an optional 'message_factory' + // parameter to the GetOptions() function. + PyMessageFactory* message_factory = + GetDefaultDescriptorPool()->py_message_factory; + + // Build the Options object: get its Python class, and make a copy of the C++ + // read-only instance. + const Message& options(descriptor->options()); + const Descriptor *message_type = options.GetDescriptor(); + CMessageClass* message_class = message_factory::GetOrCreateMessageClass( + message_factory, message_type); + if (message_class == NULL) { + PyErr_Format(PyExc_TypeError, "Could not retrieve class for Options: %s", + message_type->full_name().c_str()); + return NULL; + } + ScopedPyObjectPtr args(PyTuple_New(0)); + ScopedPyObjectPtr value( + PyObject_Call(message_class->AsPyObject(), args.get(), NULL)); + Py_DECREF(message_class); + if (value == NULL) { + return NULL; + } + if (!PyObject_TypeCheck(value.get(), CMessage_Type)) { + PyErr_Format(PyExc_TypeError, "Invalid class for %s: %s", + message_type->full_name().c_str(), + Py_TYPE(value.get())->tp_name); + return NULL; + } + CMessage* cmsg = reinterpret_cast<CMessage*>(value.get()); + + const Reflection* reflection = options.GetReflection(); + const UnknownFieldSet& unknown_fields(reflection->GetUnknownFields(options)); + if (unknown_fields.empty()) { + cmsg->message->CopyFrom(options); + } else { + // Reparse options string! XXX call cmessage::MergeFromString + if (!Reparse(message_factory, options, cmsg->message)) { + PyErr_Format(PyExc_ValueError, "Error reparsing Options message"); + return NULL; + } + } + + // Cache the result. + Py_INCREF(value.get()); + (*descriptor_options)[descriptor] = value.get(); + + return value.release(); +} + +// Copy the C++ descriptor to a Python message. +// The Python message is an instance of descriptor_pb2.DescriptorProto +// or similar. +template<class DescriptorProtoClass, class DescriptorClass> +static PyObject* CopyToPythonProto(const DescriptorClass *descriptor, + PyObject *target) { + const Descriptor* self_descriptor = + DescriptorProtoClass::default_instance().GetDescriptor(); + CMessage* message = reinterpret_cast<CMessage*>(target); + if (!PyObject_TypeCheck(target, CMessage_Type) || + message->message->GetDescriptor() != self_descriptor) { + PyErr_Format(PyExc_TypeError, "Not a %s message", + self_descriptor->full_name().c_str()); + return NULL; + } + cmessage::AssureWritable(message); + DescriptorProtoClass* descriptor_message = + static_cast<DescriptorProtoClass*>(message->message); + descriptor->CopyTo(descriptor_message); + // Custom options might in unknown extensions. Reparse + // the descriptor_message. Can't skip reparse when options unknown + // fields is empty, because they might in sub descriptors' options. + PyMessageFactory* message_factory = + GetDefaultDescriptorPool()->py_message_factory; + if (!Reparse(message_factory, *descriptor_message, descriptor_message)) { + PyErr_Format(PyExc_ValueError, "Error reparsing descriptor message"); + return nullptr; + } + + Py_RETURN_NONE; +} + +// All Descriptors classes share the same memory layout. +typedef struct PyBaseDescriptor { + PyObject_HEAD + + // Pointer to the C++ proto2 descriptor. + // Like all descriptors, it is owned by the global DescriptorPool. + const void* descriptor; + + // Owned reference to the DescriptorPool, to ensure it is kept alive. + PyDescriptorPool* pool; +} PyBaseDescriptor; + + +// FileDescriptor structure "inherits" from the base descriptor. +typedef struct PyFileDescriptor { + PyBaseDescriptor base; + + // The cached version of serialized pb. Either NULL, or a Bytes string. + // We own the reference. + PyObject *serialized_pb; +} PyFileDescriptor; + + +namespace descriptor { + +// Creates or retrieve a Python descriptor of the specified type. +// Objects are interned: the same descriptor will return the same object if it +// was kept alive. +// 'was_created' is an optional pointer to a bool, and is set to true if a new +// object was allocated. +// Always return a new reference. +template<class DescriptorClass> +PyObject* NewInternedDescriptor(PyTypeObject* type, + const DescriptorClass* descriptor, + bool* was_created) { + if (was_created) { + *was_created = false; + } + if (descriptor == NULL) { + PyErr_BadInternalCall(); + return NULL; + } + + // See if the object is in the map of interned descriptors + std::unordered_map<const void*, PyObject*>::iterator it = + interned_descriptors->find(descriptor); + if (it != interned_descriptors->end()) { + GOOGLE_DCHECK(Py_TYPE(it->second) == type); + Py_INCREF(it->second); + return it->second; + } + // Create a new descriptor object + PyBaseDescriptor* py_descriptor = PyObject_GC_New( + PyBaseDescriptor, type); + if (py_descriptor == NULL) { + return NULL; + } + py_descriptor->descriptor = descriptor; + + // and cache it. + interned_descriptors->insert( + std::make_pair(descriptor, reinterpret_cast<PyObject*>(py_descriptor))); + + // Ensures that the DescriptorPool stays alive. + PyDescriptorPool* pool = GetDescriptorPool_FromPool( + GetFileDescriptor(descriptor)->pool()); + if (pool == NULL) { + // Don't DECREF, the object is not fully initialized. + PyObject_Del(py_descriptor); + return NULL; + } + Py_INCREF(pool); + py_descriptor->pool = pool; + + PyObject_GC_Track(py_descriptor); + + if (was_created) { + *was_created = true; + } + return reinterpret_cast<PyObject*>(py_descriptor); +} + +static void Dealloc(PyObject* pself) { + PyBaseDescriptor* self = reinterpret_cast<PyBaseDescriptor*>(pself); + // Remove from interned dictionary + interned_descriptors->erase(self->descriptor); + Py_CLEAR(self->pool); + Py_TYPE(self)->tp_free(pself); +} + +static int GcTraverse(PyObject* pself, visitproc visit, void* arg) { + PyBaseDescriptor* self = reinterpret_cast<PyBaseDescriptor*>(pself); + Py_VISIT(self->pool); + return 0; +} + +static int GcClear(PyObject* pself) { + PyBaseDescriptor* self = reinterpret_cast<PyBaseDescriptor*>(pself); + Py_CLEAR(self->pool); + return 0; +} + +static PyGetSetDef Getters[] = { + {NULL} +}; + +PyTypeObject PyBaseDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME + ".DescriptorBase", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + (destructor)Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, // tp_flags + "Descriptors base class", // tp_doc + GcTraverse, // tp_traverse + GcClear, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + 0, // tp_methods + 0, // tp_members + Getters, // tp_getset +}; + +} // namespace descriptor + +const void* PyDescriptor_AsVoidPtr(PyObject* obj) { + if (!PyObject_TypeCheck(obj, &descriptor::PyBaseDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a BaseDescriptor"); + return NULL; + } + return reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor; +} + +namespace message_descriptor { + +// Unchecked accessor to the C++ pointer. +static const Descriptor* _GetDescriptor(PyBaseDescriptor* self) { + return reinterpret_cast<const Descriptor*>(self->descriptor); +} + +static PyObject* GetName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->full_name()); +} + +static PyObject* GetFile(PyBaseDescriptor *self, void *closure) { + return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file()); +} + +static PyObject* GetConcreteClass(PyBaseDescriptor* self, void *closure) { + // Returns the canonical class for the given descriptor. + // This is the class that was registered with the primary descriptor pool + // which contains this descriptor. + // This might not be the one you expect! For example the returned object does + // not know about extensions defined in a custom pool. + CMessageClass* concrete_class(message_factory::GetMessageClass( + GetDescriptorPool_FromPool( + _GetDescriptor(self)->file()->pool())->py_message_factory, + _GetDescriptor(self))); + Py_XINCREF(concrete_class); + return concrete_class->AsPyObject(); +} + +static PyObject* GetFieldsByName(PyBaseDescriptor* self, void *closure) { + return NewMessageFieldsByName(_GetDescriptor(self)); +} + +static PyObject* GetFieldsByCamelcaseName(PyBaseDescriptor* self, + void *closure) { + return NewMessageFieldsByCamelcaseName(_GetDescriptor(self)); +} + +static PyObject* GetFieldsByNumber(PyBaseDescriptor* self, void *closure) { + return NewMessageFieldsByNumber(_GetDescriptor(self)); +} + +static PyObject* GetFieldsSeq(PyBaseDescriptor* self, void *closure) { + return NewMessageFieldsSeq(_GetDescriptor(self)); +} + +static PyObject* GetNestedTypesByName(PyBaseDescriptor* self, void *closure) { + return NewMessageNestedTypesByName(_GetDescriptor(self)); +} + +static PyObject* GetNestedTypesSeq(PyBaseDescriptor* self, void *closure) { + return NewMessageNestedTypesSeq(_GetDescriptor(self)); +} + +static PyObject* GetExtensionsByName(PyBaseDescriptor* self, void *closure) { + return NewMessageExtensionsByName(_GetDescriptor(self)); +} + +static PyObject* GetExtensions(PyBaseDescriptor* self, void *closure) { + return NewMessageExtensionsSeq(_GetDescriptor(self)); +} + +static PyObject* GetEnumsSeq(PyBaseDescriptor* self, void *closure) { + return NewMessageEnumsSeq(_GetDescriptor(self)); +} + +static PyObject* GetEnumTypesByName(PyBaseDescriptor* self, void *closure) { + return NewMessageEnumsByName(_GetDescriptor(self)); +} + +static PyObject* GetEnumValuesByName(PyBaseDescriptor* self, void *closure) { + return NewMessageEnumValuesByName(_GetDescriptor(self)); +} + +static PyObject* GetOneofsByName(PyBaseDescriptor* self, void *closure) { + return NewMessageOneofsByName(_GetDescriptor(self)); +} + +static PyObject* GetOneofsSeq(PyBaseDescriptor* self, void *closure) { + return NewMessageOneofsSeq(_GetDescriptor(self)); +} + +static PyObject* IsExtendable(PyBaseDescriptor *self, void *closure) { + if (_GetDescriptor(self)->extension_range_count() > 0) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + +static PyObject* GetExtensionRanges(PyBaseDescriptor *self, void *closure) { + const Descriptor* descriptor = _GetDescriptor(self); + PyObject* range_list = PyList_New(descriptor->extension_range_count()); + + for (int i = 0; i < descriptor->extension_range_count(); i++) { + const Descriptor::ExtensionRange* range = descriptor->extension_range(i); + PyObject* start = PyInt_FromLong(range->start); + PyObject* end = PyInt_FromLong(range->end); + PyList_SetItem(range_list, i, PyTuple_Pack(2, start, end)); + } + + return range_list; +} + +static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { + const Descriptor* containing_type = + _GetDescriptor(self)->containing_type(); + if (containing_type) { + return PyMessageDescriptor_FromDescriptor(containing_type); + } else { + Py_RETURN_NONE; + } +} + +static int SetContainingType(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("containing_type"); +} + +static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) { + const MessageOptions& options(_GetDescriptor(self)->options()); + if (&options != &MessageOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} +static int SetHasOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static int SetOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); +} + +static int SetSerializedOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_serialized_options"); +} + +static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) { + return CopyToPythonProto<DescriptorProto>(_GetDescriptor(self), target); +} + +static PyObject* EnumValueName(PyBaseDescriptor *self, PyObject *args) { + const char *enum_name; + int number; + if (!PyArg_ParseTuple(args, "si", &enum_name, &number)) + return NULL; + const EnumDescriptor *enum_type = + _GetDescriptor(self)->FindEnumTypeByName(enum_name); + if (enum_type == NULL) { + PyErr_SetString(PyExc_KeyError, enum_name); + return NULL; + } + const EnumValueDescriptor *enum_value = + enum_type->FindValueByNumber(number); + if (enum_value == NULL) { + PyErr_Format(PyExc_KeyError, "%d", number); + return NULL; + } + return PyString_FromCppString(enum_value->name()); +} + +static PyObject* GetSyntax(PyBaseDescriptor *self, void *closure) { + return PyString_InternFromString( + FileDescriptor::SyntaxName(_GetDescriptor(self)->file()->syntax())); +} + +static PyGetSetDef Getters[] = { + { "name", (getter)GetName, NULL, "Last name"}, + { "full_name", (getter)GetFullName, NULL, "Full name"}, + { "_concrete_class", (getter)GetConcreteClass, NULL, "concrete class"}, + { "file", (getter)GetFile, NULL, "File descriptor"}, + + { "fields", (getter)GetFieldsSeq, NULL, "Fields sequence"}, + { "fields_by_name", (getter)GetFieldsByName, NULL, "Fields by name"}, + { "fields_by_camelcase_name", (getter)GetFieldsByCamelcaseName, NULL, + "Fields by camelCase name"}, + { "fields_by_number", (getter)GetFieldsByNumber, NULL, "Fields by number"}, + { "nested_types", (getter)GetNestedTypesSeq, NULL, "Nested types sequence"}, + { "nested_types_by_name", (getter)GetNestedTypesByName, NULL, + "Nested types by name"}, + { "extensions", (getter)GetExtensions, NULL, "Extensions Sequence"}, + { "extensions_by_name", (getter)GetExtensionsByName, NULL, + "Extensions by name"}, + { "extension_ranges", (getter)GetExtensionRanges, NULL, "Extension ranges"}, + { "enum_types", (getter)GetEnumsSeq, NULL, "Enum sequence"}, + { "enum_types_by_name", (getter)GetEnumTypesByName, NULL, + "Enum types by name"}, + { "enum_values_by_name", (getter)GetEnumValuesByName, NULL, + "Enum values by name"}, + { "oneofs_by_name", (getter)GetOneofsByName, NULL, "Oneofs by name"}, + { "oneofs", (getter)GetOneofsSeq, NULL, "Oneofs by name"}, + { "containing_type", (getter)GetContainingType, (setter)SetContainingType, + "Containing type"}, + { "is_extendable", (getter)IsExtendable, (setter)NULL}, + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, + { "_serialized_options", (getter)NULL, (setter)SetSerializedOptions, + "Serialized Options"}, + { "syntax", (getter)GetSyntax, (setter)NULL, "Syntax"}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, + { "CopyToProto", (PyCFunction)CopyToProto, METH_O, }, + { "EnumValueName", (PyCFunction)EnumValueName, METH_VARARGS, }, + {NULL} +}; + +} // namespace message_descriptor + +PyTypeObject PyMessageDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".MessageDescriptor", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Message Descriptor", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + message_descriptor::Methods, // tp_methods + 0, // tp_members + message_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base +}; + +PyObject* PyMessageDescriptor_FromDescriptor( + const Descriptor* message_descriptor) { + return descriptor::NewInternedDescriptor( + &PyMessageDescriptor_Type, message_descriptor, NULL); +} + +const Descriptor* PyMessageDescriptor_AsDescriptor(PyObject* obj) { + if (!PyObject_TypeCheck(obj, &PyMessageDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a MessageDescriptor"); + return NULL; + } + return reinterpret_cast<const Descriptor*>( + reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor); +} + +namespace field_descriptor { + +// Unchecked accessor to the C++ pointer. +static const FieldDescriptor* _GetDescriptor( + PyBaseDescriptor *self) { + return reinterpret_cast<const FieldDescriptor*>(self->descriptor); +} + +static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->full_name()); +} + +static PyObject* GetName(PyBaseDescriptor *self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetCamelcaseName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->camelcase_name()); +} + +static PyObject* GetJsonName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->json_name()); +} + +static PyObject* GetFile(PyBaseDescriptor *self, void *closure) { + return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file()); +} + +static PyObject* GetType(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->type()); +} + +static PyObject* GetCppType(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->cpp_type()); +} + +static PyObject* GetLabel(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->label()); +} + +static PyObject* GetNumber(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->number()); +} + +static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->index()); +} + +static PyObject* GetID(PyBaseDescriptor *self, void *closure) { + return PyLong_FromVoidPtr(self); +} + +static PyObject* IsExtension(PyBaseDescriptor *self, void *closure) { + return PyBool_FromLong(_GetDescriptor(self)->is_extension()); +} + +static PyObject* HasDefaultValue(PyBaseDescriptor *self, void *closure) { + return PyBool_FromLong(_GetDescriptor(self)->has_default_value()); +} + +static PyObject* GetDefaultValue(PyBaseDescriptor *self, void *closure) { + PyObject *result; + + if (_GetDescriptor(self)->is_repeated()) { + return PyList_New(0); + } + + + switch (_GetDescriptor(self)->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: { + int32_t value = _GetDescriptor(self)->default_value_int32(); + result = PyInt_FromLong(value); + break; + } + case FieldDescriptor::CPPTYPE_INT64: { + int64_t value = _GetDescriptor(self)->default_value_int64(); + result = PyLong_FromLongLong(value); + break; + } + case FieldDescriptor::CPPTYPE_UINT32: { + uint32_t value = _GetDescriptor(self)->default_value_uint32(); + result = PyInt_FromSize_t(value); + break; + } + case FieldDescriptor::CPPTYPE_UINT64: { + uint64_t value = _GetDescriptor(self)->default_value_uint64(); + result = PyLong_FromUnsignedLongLong(value); + break; + } + case FieldDescriptor::CPPTYPE_FLOAT: { + float value = _GetDescriptor(self)->default_value_float(); + result = PyFloat_FromDouble(value); + break; + } + case FieldDescriptor::CPPTYPE_DOUBLE: { + double value = _GetDescriptor(self)->default_value_double(); + result = PyFloat_FromDouble(value); + break; + } + case FieldDescriptor::CPPTYPE_BOOL: { + bool value = _GetDescriptor(self)->default_value_bool(); + result = PyBool_FromLong(value); + break; + } + case FieldDescriptor::CPPTYPE_STRING: { + const TProtoStringType& value = _GetDescriptor(self)->default_value_string(); + result = ToStringObject(_GetDescriptor(self), value); + break; + } + case FieldDescriptor::CPPTYPE_ENUM: { + const EnumValueDescriptor* value = + _GetDescriptor(self)->default_value_enum(); + result = PyInt_FromLong(value->number()); + break; + } + case FieldDescriptor::CPPTYPE_MESSAGE: { + Py_RETURN_NONE; + break; + } + default: + PyErr_Format(PyExc_NotImplementedError, "default value for %s", + _GetDescriptor(self)->full_name().c_str()); + return NULL; + } + return result; +} + +static PyObject* GetCDescriptor(PyObject *self, void *closure) { + Py_INCREF(self); + return self; +} + +static PyObject *GetEnumType(PyBaseDescriptor *self, void *closure) { + const EnumDescriptor* enum_type = _GetDescriptor(self)->enum_type(); + if (enum_type) { + return PyEnumDescriptor_FromDescriptor(enum_type); + } else { + Py_RETURN_NONE; + } +} + +static int SetEnumType(PyBaseDescriptor *self, PyObject *value, void *closure) { + return CheckCalledFromGeneratedFile("enum_type"); +} + +static PyObject *GetMessageType(PyBaseDescriptor *self, void *closure) { + const Descriptor* message_type = _GetDescriptor(self)->message_type(); + if (message_type) { + return PyMessageDescriptor_FromDescriptor(message_type); + } else { + Py_RETURN_NONE; + } +} + +static int SetMessageType(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("message_type"); +} + +static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { + const Descriptor* containing_type = + _GetDescriptor(self)->containing_type(); + if (containing_type) { + return PyMessageDescriptor_FromDescriptor(containing_type); + } else { + Py_RETURN_NONE; + } +} + +static int SetContainingType(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("containing_type"); +} + +static PyObject* GetExtensionScope(PyBaseDescriptor *self, void *closure) { + const Descriptor* extension_scope = + _GetDescriptor(self)->extension_scope(); + if (extension_scope) { + return PyMessageDescriptor_FromDescriptor(extension_scope); + } else { + Py_RETURN_NONE; + } +} + +static PyObject* GetContainingOneof(PyBaseDescriptor *self, void *closure) { + const OneofDescriptor* containing_oneof = + _GetDescriptor(self)->containing_oneof(); + if (containing_oneof) { + return PyOneofDescriptor_FromDescriptor(containing_oneof); + } else { + Py_RETURN_NONE; + } +} + +static int SetContainingOneof(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("containing_oneof"); +} + +static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) { + const FieldOptions& options(_GetDescriptor(self)->options()); + if (&options != &FieldOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} +static int SetHasOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static int SetOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); +} + +static int SetSerializedOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_serialized_options"); +} + +static PyGetSetDef Getters[] = { + { "full_name", (getter)GetFullName, NULL, "Full name"}, + { "name", (getter)GetName, NULL, "Unqualified name"}, + { "camelcase_name", (getter)GetCamelcaseName, NULL, "Camelcase name"}, + { "json_name", (getter)GetJsonName, NULL, "Json name"}, + { "file", (getter)GetFile, NULL, "File Descriptor"}, + { "type", (getter)GetType, NULL, "C++ Type"}, + { "cpp_type", (getter)GetCppType, NULL, "C++ Type"}, + { "label", (getter)GetLabel, NULL, "Label"}, + { "number", (getter)GetNumber, NULL, "Number"}, + { "index", (getter)GetIndex, NULL, "Index"}, + { "default_value", (getter)GetDefaultValue, NULL, "Default Value"}, + { "has_default_value", (getter)HasDefaultValue}, + { "is_extension", (getter)IsExtension, NULL, "ID"}, + { "id", (getter)GetID, NULL, "ID"}, + { "_cdescriptor", (getter)GetCDescriptor, NULL, "HAACK REMOVE ME"}, + + { "message_type", (getter)GetMessageType, (setter)SetMessageType, + "Message type"}, + { "enum_type", (getter)GetEnumType, (setter)SetEnumType, "Enum type"}, + { "containing_type", (getter)GetContainingType, (setter)SetContainingType, + "Containing type"}, + { "extension_scope", (getter)GetExtensionScope, (setter)NULL, + "Extension scope"}, + { "containing_oneof", (getter)GetContainingOneof, (setter)SetContainingOneof, + "Containing oneof"}, + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, + { "_serialized_options", (getter)NULL, (setter)SetSerializedOptions, + "Serialized Options"}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, + {NULL} +}; + +} // namespace field_descriptor + +PyTypeObject PyFieldDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".FieldDescriptor", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Field Descriptor", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + field_descriptor::Methods, // tp_methods + 0, // tp_members + field_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base +}; + +PyObject* PyFieldDescriptor_FromDescriptor( + const FieldDescriptor* field_descriptor) { + return descriptor::NewInternedDescriptor( + &PyFieldDescriptor_Type, field_descriptor, NULL); +} + +const FieldDescriptor* PyFieldDescriptor_AsDescriptor(PyObject* obj) { + if (!PyObject_TypeCheck(obj, &PyFieldDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a FieldDescriptor"); + return NULL; + } + return reinterpret_cast<const FieldDescriptor*>( + reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor); +} + +namespace enum_descriptor { + +// Unchecked accessor to the C++ pointer. +static const EnumDescriptor* _GetDescriptor( + PyBaseDescriptor *self) { + return reinterpret_cast<const EnumDescriptor*>(self->descriptor); +} + +static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->full_name()); +} + +static PyObject* GetName(PyBaseDescriptor *self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetFile(PyBaseDescriptor *self, void *closure) { + return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file()); +} + +static PyObject* GetEnumvaluesByName(PyBaseDescriptor* self, void *closure) { + return NewEnumValuesByName(_GetDescriptor(self)); +} + +static PyObject* GetEnumvaluesByNumber(PyBaseDescriptor* self, void *closure) { + return NewEnumValuesByNumber(_GetDescriptor(self)); +} + +static PyObject* GetEnumvaluesSeq(PyBaseDescriptor* self, void *closure) { + return NewEnumValuesSeq(_GetDescriptor(self)); +} + +static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { + const Descriptor* containing_type = + _GetDescriptor(self)->containing_type(); + if (containing_type) { + return PyMessageDescriptor_FromDescriptor(containing_type); + } else { + Py_RETURN_NONE; + } +} + +static int SetContainingType(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("containing_type"); +} + + +static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) { + const EnumOptions& options(_GetDescriptor(self)->options()); + if (&options != &EnumOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} +static int SetHasOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static int SetOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); +} + +static int SetSerializedOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_serialized_options"); +} + +static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) { + return CopyToPythonProto<EnumDescriptorProto>(_GetDescriptor(self), target); +} + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, + { "CopyToProto", (PyCFunction)CopyToProto, METH_O, }, + {NULL} +}; + +static PyGetSetDef Getters[] = { + { "full_name", (getter)GetFullName, NULL, "Full name"}, + { "name", (getter)GetName, NULL, "last name"}, + { "file", (getter)GetFile, NULL, "File descriptor"}, + { "values", (getter)GetEnumvaluesSeq, NULL, "values"}, + { "values_by_name", (getter)GetEnumvaluesByName, NULL, + "Enum values by name"}, + { "values_by_number", (getter)GetEnumvaluesByNumber, NULL, + "Enum values by number"}, + + { "containing_type", (getter)GetContainingType, (setter)SetContainingType, + "Containing type"}, + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, + { "_serialized_options", (getter)NULL, (setter)SetSerializedOptions, + "Serialized Options"}, + {NULL} +}; + +} // namespace enum_descriptor + +PyTypeObject PyEnumDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".EnumDescriptor", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Enum Descriptor", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + enum_descriptor::Methods, // tp_methods + 0, // tp_members + enum_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base +}; + +PyObject* PyEnumDescriptor_FromDescriptor( + const EnumDescriptor* enum_descriptor) { + return descriptor::NewInternedDescriptor( + &PyEnumDescriptor_Type, enum_descriptor, NULL); +} + +const EnumDescriptor* PyEnumDescriptor_AsDescriptor(PyObject* obj) { + if (!PyObject_TypeCheck(obj, &PyEnumDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not an EnumDescriptor"); + return NULL; + } + return reinterpret_cast<const EnumDescriptor*>( + reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor); +} + +namespace enumvalue_descriptor { + +// Unchecked accessor to the C++ pointer. +static const EnumValueDescriptor* _GetDescriptor( + PyBaseDescriptor *self) { + return reinterpret_cast<const EnumValueDescriptor*>(self->descriptor); +} + +static PyObject* GetName(PyBaseDescriptor *self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetNumber(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->number()); +} + +static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->index()); +} + +static PyObject* GetType(PyBaseDescriptor *self, void *closure) { + return PyEnumDescriptor_FromDescriptor(_GetDescriptor(self)->type()); +} + +static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) { + const EnumValueOptions& options(_GetDescriptor(self)->options()); + if (&options != &EnumValueOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} +static int SetHasOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static int SetOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); +} + +static int SetSerializedOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_serialized_options"); +} + +static PyGetSetDef Getters[] = { + { "name", (getter)GetName, NULL, "name"}, + { "number", (getter)GetNumber, NULL, "number"}, + { "index", (getter)GetIndex, NULL, "index"}, + { "type", (getter)GetType, NULL, "index"}, + + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, + { "_serialized_options", (getter)NULL, (setter)SetSerializedOptions, + "Serialized Options"}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, + {NULL} +}; + +} // namespace enumvalue_descriptor + +PyTypeObject PyEnumValueDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".EnumValueDescriptor", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A EnumValue Descriptor", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + enumvalue_descriptor::Methods, // tp_methods + 0, // tp_members + enumvalue_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base +}; + +PyObject* PyEnumValueDescriptor_FromDescriptor( + const EnumValueDescriptor* enumvalue_descriptor) { + return descriptor::NewInternedDescriptor( + &PyEnumValueDescriptor_Type, enumvalue_descriptor, NULL); +} + +namespace file_descriptor { + +// Unchecked accessor to the C++ pointer. +static const FileDescriptor* _GetDescriptor(PyFileDescriptor *self) { + return reinterpret_cast<const FileDescriptor*>(self->base.descriptor); +} + +static void Dealloc(PyFileDescriptor* self) { + Py_XDECREF(self->serialized_pb); + descriptor::Dealloc(reinterpret_cast<PyObject*>(self)); +} + +static PyObject* GetPool(PyFileDescriptor *self, void *closure) { + PyObject* pool = reinterpret_cast<PyObject*>( + GetDescriptorPool_FromPool(_GetDescriptor(self)->pool())); + Py_XINCREF(pool); + return pool; +} + +static PyObject* GetName(PyFileDescriptor *self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetPackage(PyFileDescriptor *self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->package()); +} + +static PyObject* GetSerializedPb(PyFileDescriptor *self, void *closure) { + PyObject *serialized_pb = self->serialized_pb; + if (serialized_pb != NULL) { + Py_INCREF(serialized_pb); + return serialized_pb; + } + FileDescriptorProto file_proto; + _GetDescriptor(self)->CopyTo(&file_proto); + TProtoStringType contents; + file_proto.SerializePartialToString(&contents); + self->serialized_pb = PyBytes_FromStringAndSize( + contents.c_str(), contents.size()); + if (self->serialized_pb == NULL) { + return NULL; + } + Py_INCREF(self->serialized_pb); + return self->serialized_pb; +} + +static PyObject* GetMessageTypesByName(PyFileDescriptor* self, void *closure) { + return NewFileMessageTypesByName(_GetDescriptor(self)); +} + +static PyObject* GetEnumTypesByName(PyFileDescriptor* self, void *closure) { + return NewFileEnumTypesByName(_GetDescriptor(self)); +} + +static PyObject* GetExtensionsByName(PyFileDescriptor* self, void *closure) { + return NewFileExtensionsByName(_GetDescriptor(self)); +} + +static PyObject* GetServicesByName(PyFileDescriptor* self, void *closure) { + return NewFileServicesByName(_GetDescriptor(self)); +} + +static PyObject* GetDependencies(PyFileDescriptor* self, void *closure) { + return NewFileDependencies(_GetDescriptor(self)); +} + +static PyObject* GetPublicDependencies(PyFileDescriptor* self, void *closure) { + return NewFilePublicDependencies(_GetDescriptor(self)); +} + +static PyObject* GetHasOptions(PyFileDescriptor *self, void *closure) { + const FileOptions& options(_GetDescriptor(self)->options()); + if (&options != &FileOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} +static int SetHasOptions(PyFileDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} + +static PyObject* GetOptions(PyFileDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static int SetOptions(PyFileDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); +} + +static int SetSerializedOptions(PyFileDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_serialized_options"); +} + +static PyObject* GetSyntax(PyFileDescriptor *self, void *closure) { + return PyString_InternFromString( + FileDescriptor::SyntaxName(_GetDescriptor(self)->syntax())); +} + +static PyObject* CopyToProto(PyFileDescriptor *self, PyObject *target) { + return CopyToPythonProto<FileDescriptorProto>(_GetDescriptor(self), target); +} + +static PyGetSetDef Getters[] = { + { "pool", (getter)GetPool, NULL, "pool"}, + { "name", (getter)GetName, NULL, "name"}, + { "package", (getter)GetPackage, NULL, "package"}, + { "serialized_pb", (getter)GetSerializedPb}, + { "message_types_by_name", (getter)GetMessageTypesByName, NULL, + "Messages by name"}, + { "enum_types_by_name", (getter)GetEnumTypesByName, NULL, "Enums by name"}, + { "extensions_by_name", (getter)GetExtensionsByName, NULL, + "Extensions by name"}, + { "services_by_name", (getter)GetServicesByName, NULL, "Services by name"}, + { "dependencies", (getter)GetDependencies, NULL, "Dependencies"}, + { "public_dependencies", (getter)GetPublicDependencies, NULL, "Dependencies"}, + + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, + { "_serialized_options", (getter)NULL, (setter)SetSerializedOptions, + "Serialized Options"}, + { "syntax", (getter)GetSyntax, (setter)NULL, "Syntax"}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, + { "CopyToProto", (PyCFunction)CopyToProto, METH_O, }, + {NULL} +}; + +} // namespace file_descriptor + +PyTypeObject PyFileDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME + ".FileDescriptor", // tp_name + sizeof(PyFileDescriptor), // tp_basicsize + 0, // tp_itemsize + (destructor)file_descriptor::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A File Descriptor", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + file_descriptor::Methods, // tp_methods + 0, // tp_members + file_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + 0, // tp_new + PyObject_GC_Del, // tp_free +}; + +PyObject* PyFileDescriptor_FromDescriptor( + const FileDescriptor* file_descriptor) { + return PyFileDescriptor_FromDescriptorWithSerializedPb(file_descriptor, + NULL); +} + +PyObject* PyFileDescriptor_FromDescriptorWithSerializedPb( + const FileDescriptor* file_descriptor, PyObject *serialized_pb) { + bool was_created; + PyObject* py_descriptor = descriptor::NewInternedDescriptor( + &PyFileDescriptor_Type, file_descriptor, &was_created); + if (py_descriptor == NULL) { + return NULL; + } + if (was_created) { + PyFileDescriptor* cfile_descriptor = + reinterpret_cast<PyFileDescriptor*>(py_descriptor); + Py_XINCREF(serialized_pb); + cfile_descriptor->serialized_pb = serialized_pb; + } + // TODO(amauryfa): In the case of a cached object, check that serialized_pb + // is the same as before. + + return py_descriptor; +} + +const FileDescriptor* PyFileDescriptor_AsDescriptor(PyObject* obj) { + if (!PyObject_TypeCheck(obj, &PyFileDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a FileDescriptor"); + return NULL; + } + return reinterpret_cast<const FileDescriptor*>( + reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor); +} + +namespace oneof_descriptor { + +// Unchecked accessor to the C++ pointer. +static const OneofDescriptor* _GetDescriptor( + PyBaseDescriptor *self) { + return reinterpret_cast<const OneofDescriptor*>(self->descriptor); +} + +static PyObject* GetName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->full_name()); +} + +static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->index()); +} + +static PyObject* GetFields(PyBaseDescriptor* self, void *closure) { + return NewOneofFieldsSeq(_GetDescriptor(self)); +} + +static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { + const Descriptor* containing_type = + _GetDescriptor(self)->containing_type(); + if (containing_type) { + return PyMessageDescriptor_FromDescriptor(containing_type); + } else { + Py_RETURN_NONE; + } +} + +static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) { + const OneofOptions& options(_GetDescriptor(self)->options()); + if (&options != &OneofOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} +static int SetHasOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static int SetOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); +} + +static int SetSerializedOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_serialized_options"); +} + +static PyGetSetDef Getters[] = { + { "name", (getter)GetName, NULL, "Name"}, + { "full_name", (getter)GetFullName, NULL, "Full name"}, + { "index", (getter)GetIndex, NULL, "Index"}, + + { "containing_type", (getter)GetContainingType, NULL, "Containing type"}, + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, + { "_serialized_options", (getter)NULL, (setter)SetSerializedOptions, + "Serialized Options"}, + { "fields", (getter)GetFields, NULL, "Fields"}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS }, + {NULL} +}; + +} // namespace oneof_descriptor + +PyTypeObject PyOneofDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".OneofDescriptor", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Oneof Descriptor", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + oneof_descriptor::Methods, // tp_methods + 0, // tp_members + oneof_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base +}; + +PyObject* PyOneofDescriptor_FromDescriptor( + const OneofDescriptor* oneof_descriptor) { + return descriptor::NewInternedDescriptor( + &PyOneofDescriptor_Type, oneof_descriptor, NULL); +} + +namespace service_descriptor { + +// Unchecked accessor to the C++ pointer. +static const ServiceDescriptor* _GetDescriptor( + PyBaseDescriptor *self) { + return reinterpret_cast<const ServiceDescriptor*>(self->descriptor); +} + +static PyObject* GetName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->full_name()); +} + +static PyObject* GetFile(PyBaseDescriptor *self, void *closure) { + return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file()); +} + +static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->index()); +} + +static PyObject* GetMethods(PyBaseDescriptor* self, void *closure) { + return NewServiceMethodsSeq(_GetDescriptor(self)); +} + +static PyObject* GetMethodsByName(PyBaseDescriptor* self, void *closure) { + return NewServiceMethodsByName(_GetDescriptor(self)); +} + +static PyObject* FindMethodByName(PyBaseDescriptor *self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const MethodDescriptor* method_descriptor = + _GetDescriptor(self)->FindMethodByName(StringParam(name, name_size)); + if (method_descriptor == NULL) { + PyErr_Format(PyExc_KeyError, "Couldn't find method %.200s", name); + return NULL; + } + + return PyMethodDescriptor_FromDescriptor(method_descriptor); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) { + return CopyToPythonProto<ServiceDescriptorProto>(_GetDescriptor(self), + target); +} + +static PyGetSetDef Getters[] = { + { "name", (getter)GetName, NULL, "Name", NULL}, + { "full_name", (getter)GetFullName, NULL, "Full name", NULL}, + { "file", (getter)GetFile, NULL, "File descriptor"}, + { "index", (getter)GetIndex, NULL, "Index", NULL}, + + { "methods", (getter)GetMethods, NULL, "Methods", NULL}, + { "methods_by_name", (getter)GetMethodsByName, NULL, "Methods by name", NULL}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS }, + { "CopyToProto", (PyCFunction)CopyToProto, METH_O, }, + { "FindMethodByName", (PyCFunction)FindMethodByName, METH_O }, + {NULL} +}; + +} // namespace service_descriptor + +PyTypeObject PyServiceDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".ServiceDescriptor", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Service Descriptor", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + service_descriptor::Methods, // tp_methods + 0, // tp_members + service_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base +}; + +PyObject* PyServiceDescriptor_FromDescriptor( + const ServiceDescriptor* service_descriptor) { + return descriptor::NewInternedDescriptor( + &PyServiceDescriptor_Type, service_descriptor, NULL); +} + +const ServiceDescriptor* PyServiceDescriptor_AsDescriptor(PyObject* obj) { + if (!PyObject_TypeCheck(obj, &PyServiceDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a ServiceDescriptor"); + return NULL; + } + return reinterpret_cast<const ServiceDescriptor*>( + reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor); +} + +namespace method_descriptor { + +// Unchecked accessor to the C++ pointer. +static const MethodDescriptor* _GetDescriptor( + PyBaseDescriptor *self) { + return reinterpret_cast<const MethodDescriptor*>(self->descriptor); +} + +static PyObject* GetName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->full_name()); +} + +static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->index()); +} + +static PyObject* GetContainingService(PyBaseDescriptor *self, void *closure) { + const ServiceDescriptor* containing_service = + _GetDescriptor(self)->service(); + return PyServiceDescriptor_FromDescriptor(containing_service); +} + +static PyObject* GetInputType(PyBaseDescriptor *self, void *closure) { + const Descriptor* input_type = _GetDescriptor(self)->input_type(); + return PyMessageDescriptor_FromDescriptor(input_type); +} + +static PyObject* GetOutputType(PyBaseDescriptor *self, void *closure) { + const Descriptor* output_type = _GetDescriptor(self)->output_type(); + return PyMessageDescriptor_FromDescriptor(output_type); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) { + return CopyToPythonProto<MethodDescriptorProto>(_GetDescriptor(self), target); +} + +static PyGetSetDef Getters[] = { + { "name", (getter)GetName, NULL, "Name", NULL}, + { "full_name", (getter)GetFullName, NULL, "Full name", NULL}, + { "index", (getter)GetIndex, NULL, "Index", NULL}, + { "containing_service", (getter)GetContainingService, NULL, + "Containing service", NULL}, + { "input_type", (getter)GetInputType, NULL, "Input type", NULL}, + { "output_type", (getter)GetOutputType, NULL, "Output type", NULL}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, + { "CopyToProto", (PyCFunction)CopyToProto, METH_O, }, + {NULL} +}; + +} // namespace method_descriptor + +PyTypeObject PyMethodDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".MethodDescriptor", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Method Descriptor", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + method_descriptor::Methods, // tp_methods + 0, // tp_members + method_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base +}; + +PyObject* PyMethodDescriptor_FromDescriptor( + const MethodDescriptor* method_descriptor) { + return descriptor::NewInternedDescriptor( + &PyMethodDescriptor_Type, method_descriptor, NULL); +} + +// Add a enum values to a type dictionary. +static bool AddEnumValues(PyTypeObject *type, + const EnumDescriptor* enum_descriptor) { + for (int i = 0; i < enum_descriptor->value_count(); ++i) { + const EnumValueDescriptor* value = enum_descriptor->value(i); + ScopedPyObjectPtr obj(PyInt_FromLong(value->number())); + if (obj == NULL) { + return false; + } + if (PyDict_SetItemString(type->tp_dict, value->name().c_str(), obj.get()) < + 0) { + return false; + } + } + return true; +} + +static bool AddIntConstant(PyTypeObject *type, const char* name, int value) { + ScopedPyObjectPtr obj(PyInt_FromLong(value)); + if (PyDict_SetItemString(type->tp_dict, name, obj.get()) < 0) { + return false; + } + return true; +} + + +bool InitDescriptor() { + if (PyType_Ready(&PyMessageDescriptor_Type) < 0) + return false; + + if (PyType_Ready(&PyFieldDescriptor_Type) < 0) + return false; + + if (!AddEnumValues(&PyFieldDescriptor_Type, + FieldDescriptorProto::Label_descriptor())) { + return false; + } + if (!AddEnumValues(&PyFieldDescriptor_Type, + FieldDescriptorProto::Type_descriptor())) { + return false; + } +#define ADD_FIELDDESC_CONSTANT(NAME) AddIntConstant( \ + &PyFieldDescriptor_Type, #NAME, FieldDescriptor::NAME) + if (!ADD_FIELDDESC_CONSTANT(CPPTYPE_INT32) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_INT64) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_UINT32) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_UINT64) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_DOUBLE) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_FLOAT) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_BOOL) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_ENUM) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_STRING) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_MESSAGE)) { + return false; + } +#undef ADD_FIELDDESC_CONSTANT + + if (PyType_Ready(&PyEnumDescriptor_Type) < 0) + return false; + + if (PyType_Ready(&PyEnumValueDescriptor_Type) < 0) + return false; + + if (PyType_Ready(&PyFileDescriptor_Type) < 0) + return false; + + if (PyType_Ready(&PyOneofDescriptor_Type) < 0) + return false; + + if (PyType_Ready(&PyServiceDescriptor_Type) < 0) + return false; + + if (PyType_Ready(&PyMethodDescriptor_Type) < 0) + return false; + + if (!InitDescriptorMappingTypes()) + return false; + + // Initialize globals defined in this file. + interned_descriptors = new std::unordered_map<const void*, PyObject*>; + + return true; +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor.h b/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor.h new file mode 100644 index 0000000000..47efbe35d7 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor.h @@ -0,0 +1,107 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: petar@google.com (Petar Petrov) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__ + +#include <Python.h> + +#include <google/protobuf/descriptor.h> + +namespace google { +namespace protobuf { +namespace python { + +// Should match the type of ConstStringParam. +using StringParam = TProtoStringType; + +extern PyTypeObject PyMessageDescriptor_Type; +extern PyTypeObject PyFieldDescriptor_Type; +extern PyTypeObject PyEnumDescriptor_Type; +extern PyTypeObject PyEnumValueDescriptor_Type; +extern PyTypeObject PyFileDescriptor_Type; +extern PyTypeObject PyOneofDescriptor_Type; +extern PyTypeObject PyServiceDescriptor_Type; +extern PyTypeObject PyMethodDescriptor_Type; + +// Wraps a Descriptor in a Python object. +// The C++ pointer is usually borrowed from the global DescriptorPool. +// In any case, it must stay alive as long as the Python object. +// Returns a new reference. +PyObject* PyMessageDescriptor_FromDescriptor(const Descriptor* descriptor); +PyObject* PyFieldDescriptor_FromDescriptor(const FieldDescriptor* descriptor); +PyObject* PyEnumDescriptor_FromDescriptor(const EnumDescriptor* descriptor); +PyObject* PyEnumValueDescriptor_FromDescriptor( + const EnumValueDescriptor* descriptor); +PyObject* PyOneofDescriptor_FromDescriptor(const OneofDescriptor* descriptor); +PyObject* PyFileDescriptor_FromDescriptor( + const FileDescriptor* file_descriptor); +PyObject* PyServiceDescriptor_FromDescriptor( + const ServiceDescriptor* descriptor); +PyObject* PyMethodDescriptor_FromDescriptor( + const MethodDescriptor* descriptor); + +// Alternate constructor of PyFileDescriptor, used when we already have a +// serialized FileDescriptorProto that can be cached. +// Returns a new reference. +PyObject* PyFileDescriptor_FromDescriptorWithSerializedPb( + const FileDescriptor* file_descriptor, PyObject* serialized_pb); + +// Return the C++ descriptor pointer. +// This function checks the parameter type; on error, return NULL with a Python +// exception set. +const Descriptor* PyMessageDescriptor_AsDescriptor(PyObject* obj); +const FieldDescriptor* PyFieldDescriptor_AsDescriptor(PyObject* obj); +const EnumDescriptor* PyEnumDescriptor_AsDescriptor(PyObject* obj); +const FileDescriptor* PyFileDescriptor_AsDescriptor(PyObject* obj); +const ServiceDescriptor* PyServiceDescriptor_AsDescriptor(PyObject* obj); + +// Returns the raw C++ pointer. +const void* PyDescriptor_AsVoidPtr(PyObject* obj); + +// Check that the calling Python code is the global scope of a _pb2.py module. +// This function is used to support the current code generated by the proto +// compiler, which insists on modifying descriptors after they have been +// created. +// +// stacklevel indicates which Python frame should be the _pb2.py module. +// +// Don't use this function outside descriptor classes. +bool _CalledFromGeneratedFile(int stacklevel); + +bool InitDescriptor(); + +} // namespace python +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__ diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_containers.cc b/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_containers.cc new file mode 100644 index 0000000000..7990604ff0 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_containers.cc @@ -0,0 +1,1793 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Mappings and Sequences of descriptors. +// Used by Descriptor.fields_by_name, EnumDescriptor.values... +// +// They avoid the allocation of a full dictionary or a full list: they simply +// store a pointer to the parent descriptor, use the C++ Descriptor methods (see +// net/proto2/public/descriptor.h) to retrieve other descriptors, and create +// Python objects on the fly. +// +// The containers fully conform to abc.Mapping and abc.Sequence, and behave just +// like read-only dictionaries and lists. +// +// Because the interface of C++ Descriptors is quite regular, this file actually +// defines only three types, the exact behavior of a container is controlled by +// a DescriptorContainerDef structure, which contains functions that uses the +// public Descriptor API. +// +// Note: This DescriptorContainerDef is similar to the "virtual methods table" +// that a C++ compiler generates for a class. We have to make it explicit +// because the Python API is based on C, and does not play well with C++ +// inheritance. + +#include <Python.h> + +#include <google/protobuf/descriptor.h> +#include <google/protobuf/pyext/descriptor_containers.h> +#include <google/protobuf/pyext/descriptor_pool.h> +#include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +#if PY_MAJOR_VERSION >= 3 + #define PyString_FromStringAndSize PyUnicode_FromStringAndSize + #define PyString_FromFormat PyUnicode_FromFormat + #define PyInt_FromLong PyLong_FromLong + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #endif +#define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \ + PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \ + ? -1 \ + : 0) \ + : PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif + +namespace google { +namespace protobuf { +namespace python { + +struct PyContainer; + +typedef int (*CountMethod)(PyContainer* self); +typedef const void* (*GetByIndexMethod)(PyContainer* self, int index); +typedef const void* (*GetByNameMethod)(PyContainer* self, + ConstStringParam name); +typedef const void* (*GetByCamelcaseNameMethod)(PyContainer* self, + ConstStringParam name); +typedef const void* (*GetByNumberMethod)(PyContainer* self, int index); +typedef PyObject* (*NewObjectFromItemMethod)(const void* descriptor); +typedef const TProtoStringType& (*GetItemNameMethod)(const void* descriptor); +typedef const TProtoStringType& (*GetItemCamelcaseNameMethod)( + const void* descriptor); +typedef int (*GetItemNumberMethod)(const void* descriptor); +typedef int (*GetItemIndexMethod)(const void* descriptor); + +struct DescriptorContainerDef { + const char* mapping_name; + // Returns the number of items in the container. + CountMethod count_fn; + // Retrieve item by index (usually the order of declaration in the proto file) + // Used by sequences, but also iterators. 0 <= index < Count(). + GetByIndexMethod get_by_index_fn; + // Retrieve item by name (usually a call to some 'FindByName' method). + // Used by "by_name" mappings. + GetByNameMethod get_by_name_fn; + // Retrieve item by camelcase name (usually a call to some + // 'FindByCamelcaseName' method). Used by "by_camelcase_name" mappings. + GetByCamelcaseNameMethod get_by_camelcase_name_fn; + // Retrieve item by declared number (field tag, or enum value). + // Used by "by_number" mappings. + GetByNumberMethod get_by_number_fn; + // Converts a item C++ descriptor to a Python object. Returns a new reference. + NewObjectFromItemMethod new_object_from_item_fn; + // Retrieve the name of an item. Used by iterators on "by_name" mappings. + GetItemNameMethod get_item_name_fn; + // Retrieve the camelcase name of an item. Used by iterators on + // "by_camelcase_name" mappings. + GetItemCamelcaseNameMethod get_item_camelcase_name_fn; + // Retrieve the number of an item. Used by iterators on "by_number" mappings. + GetItemNumberMethod get_item_number_fn; + // Retrieve the index of an item for the container type. + // Used by "__contains__". + // If not set, "x in sequence" will do a linear search. + GetItemIndexMethod get_item_index_fn; +}; + +struct PyContainer { + PyObject_HEAD + + // The proto2 descriptor this container belongs to the global DescriptorPool. + const void* descriptor; + + // A pointer to a static structure with function pointers that control the + // behavior of the container. Very similar to the table of virtual functions + // of a C++ class. + const DescriptorContainerDef* container_def; + + // The kind of container: list, or dict by name or value. + enum ContainerKind { + KIND_SEQUENCE, + KIND_BYNAME, + KIND_BYCAMELCASENAME, + KIND_BYNUMBER, + } kind; +}; + +struct PyContainerIterator { + PyObject_HEAD + + // The container we are iterating over. Own a reference. + PyContainer* container; + + // The current index in the iterator. + int index; + + // The kind of container: list, or dict by name or value. + enum IterKind { + KIND_ITERKEY, + KIND_ITERVALUE, + KIND_ITERITEM, + KIND_ITERVALUE_REVERSED, // For sequences + } kind; +}; + +namespace descriptor { + +// Returns the C++ item descriptor for a given Python key. +// When the descriptor is found, return true and set *item. +// When the descriptor is not found, return true, but set *item to NULL. +// On error, returns false with an exception set. +static bool _GetItemByKey(PyContainer* self, PyObject* key, const void** item) { + switch (self->kind) { + case PyContainer::KIND_BYNAME: + { + char* name; + Py_ssize_t name_size; + if (PyString_AsStringAndSize(key, &name, &name_size) < 0) { + if (PyErr_ExceptionMatches(PyExc_TypeError)) { + // Not a string, cannot be in the container. + PyErr_Clear(); + *item = NULL; + return true; + } + return false; + } + *item = self->container_def->get_by_name_fn( + self, StringParam(name, name_size)); + return true; + } + case PyContainer::KIND_BYCAMELCASENAME: + { + char* camelcase_name; + Py_ssize_t name_size; + if (PyString_AsStringAndSize(key, &camelcase_name, &name_size) < 0) { + if (PyErr_ExceptionMatches(PyExc_TypeError)) { + // Not a string, cannot be in the container. + PyErr_Clear(); + *item = NULL; + return true; + } + return false; + } + *item = self->container_def->get_by_camelcase_name_fn( + self, StringParam(camelcase_name, name_size)); + return true; + } + case PyContainer::KIND_BYNUMBER: + { + Py_ssize_t number = PyNumber_AsSsize_t(key, NULL); + if (number == -1 && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_TypeError)) { + // Not a number, cannot be in the container. + PyErr_Clear(); + *item = NULL; + return true; + } + return false; + } + *item = self->container_def->get_by_number_fn(self, number); + return true; + } + default: + PyErr_SetNone(PyExc_NotImplementedError); + return false; + } +} + +// Returns the key of the object at the given index. +// Used when iterating over mappings. +static PyObject* _NewKey_ByIndex(PyContainer* self, Py_ssize_t index) { + const void* item = self->container_def->get_by_index_fn(self, index); + switch (self->kind) { + case PyContainer::KIND_BYNAME: + { + const TProtoStringType& name(self->container_def->get_item_name_fn(item)); + return PyString_FromStringAndSize(name.c_str(), name.size()); + } + case PyContainer::KIND_BYCAMELCASENAME: + { + const TProtoStringType& name( + self->container_def->get_item_camelcase_name_fn(item)); + return PyString_FromStringAndSize(name.c_str(), name.size()); + } + case PyContainer::KIND_BYNUMBER: + { + int value = self->container_def->get_item_number_fn(item); + return PyInt_FromLong(value); + } + default: + PyErr_SetNone(PyExc_NotImplementedError); + return NULL; + } +} + +// Returns the object at the given index. +// Also used when iterating over mappings. +static PyObject* _NewObj_ByIndex(PyContainer* self, Py_ssize_t index) { + return self->container_def->new_object_from_item_fn( + self->container_def->get_by_index_fn(self, index)); +} + +static Py_ssize_t Length(PyContainer* self) { + return self->container_def->count_fn(self); +} + +// The DescriptorMapping type. + +static PyObject* Subscript(PyContainer* self, PyObject* key) { + const void* item = NULL; + if (!_GetItemByKey(self, key, &item)) { + return NULL; + } + if (!item) { + PyErr_SetObject(PyExc_KeyError, key); + return NULL; + } + return self->container_def->new_object_from_item_fn(item); +} + +static int AssSubscript(PyContainer* self, PyObject* key, PyObject* value) { + if (_CalledFromGeneratedFile(0)) { + return 0; + } + PyErr_Format(PyExc_TypeError, + "'%.200s' object does not support item assignment", + Py_TYPE(self)->tp_name); + return -1; +} + +static PyMappingMethods MappingMappingMethods = { + (lenfunc)Length, // mp_length + (binaryfunc)Subscript, // mp_subscript + (objobjargproc)AssSubscript, // mp_ass_subscript +}; + +static int Contains(PyContainer* self, PyObject* key) { + const void* item = NULL; + if (!_GetItemByKey(self, key, &item)) { + return -1; + } + if (item) { + return 1; + } else { + return 0; + } +} + +static PyObject* ContainerRepr(PyContainer* self) { + const char* kind = ""; + switch (self->kind) { + case PyContainer::KIND_SEQUENCE: + kind = "sequence"; + break; + case PyContainer::KIND_BYNAME: + kind = "mapping by name"; + break; + case PyContainer::KIND_BYCAMELCASENAME: + kind = "mapping by camelCase name"; + break; + case PyContainer::KIND_BYNUMBER: + kind = "mapping by number"; + break; + } + return PyString_FromFormat( + "<%s %s>", self->container_def->mapping_name, kind); +} + +extern PyTypeObject DescriptorMapping_Type; +extern PyTypeObject DescriptorSequence_Type; + +// A sequence container can only be equal to another sequence container, or (for +// backward compatibility) to a list containing the same items. +// Returns 1 if equal, 0 if unequal, -1 on error. +static int DescriptorSequence_Equal(PyContainer* self, PyObject* other) { + // Check the identity of C++ pointers. + if (PyObject_TypeCheck(other, &DescriptorSequence_Type)) { + PyContainer* other_container = reinterpret_cast<PyContainer*>(other); + if (self->descriptor == other_container->descriptor && + self->container_def == other_container->container_def && + self->kind == other_container->kind) { + return 1; + } else { + return 0; + } + } + + // If other is a list + if (PyList_Check(other)) { + // return list(self) == other + int size = Length(self); + if (size != PyList_Size(other)) { + return false; + } + for (int index = 0; index < size; index++) { + ScopedPyObjectPtr value1(_NewObj_ByIndex(self, index)); + if (value1 == NULL) { + return -1; + } + PyObject* value2 = PyList_GetItem(other, index); + if (value2 == NULL) { + return -1; + } + int cmp = PyObject_RichCompareBool(value1.get(), value2, Py_EQ); + if (cmp != 1) // error or not equal + return cmp; + } + // All items were found and equal + return 1; + } + + // Any other object is different. + return 0; +} + +// A mapping container can only be equal to another mapping container, or (for +// backward compatibility) to a dict containing the same items. +// Returns 1 if equal, 0 if unequal, -1 on error. +static int DescriptorMapping_Equal(PyContainer* self, PyObject* other) { + // Check the identity of C++ pointers. + if (PyObject_TypeCheck(other, &DescriptorMapping_Type)) { + PyContainer* other_container = reinterpret_cast<PyContainer*>(other); + if (self->descriptor == other_container->descriptor && + self->container_def == other_container->container_def && + self->kind == other_container->kind) { + return 1; + } else { + return 0; + } + } + + // If other is a dict + if (PyDict_Check(other)) { + // equivalent to dict(self.items()) == other + int size = Length(self); + if (size != PyDict_Size(other)) { + return false; + } + for (int index = 0; index < size; index++) { + ScopedPyObjectPtr key(_NewKey_ByIndex(self, index)); + if (key == NULL) { + return -1; + } + ScopedPyObjectPtr value1(_NewObj_ByIndex(self, index)); + if (value1 == NULL) { + return -1; + } + PyObject* value2 = PyDict_GetItem(other, key.get()); + if (value2 == NULL) { + // Not found in the other dictionary + return 0; + } + int cmp = PyObject_RichCompareBool(value1.get(), value2, Py_EQ); + if (cmp != 1) // error or not equal + return cmp; + } + // All items were found and equal + return 1; + } + + // Any other object is different. + return 0; +} + +static PyObject* RichCompare(PyContainer* self, PyObject* other, int opid) { + if (opid != Py_EQ && opid != Py_NE) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + + int result; + + if (self->kind == PyContainer::KIND_SEQUENCE) { + result = DescriptorSequence_Equal(self, other); + } else { + result = DescriptorMapping_Equal(self, other); + } + if (result < 0) { + return NULL; + } + if (result ^ (opid == Py_NE)) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + +static PySequenceMethods MappingSequenceMethods = { + 0, // sq_length + 0, // sq_concat + 0, // sq_repeat + 0, // sq_item + 0, // sq_slice + 0, // sq_ass_item + 0, // sq_ass_slice + (objobjproc)Contains, // sq_contains +}; + +static PyObject* Get(PyContainer* self, PyObject* args) { + PyObject* key; + PyObject* default_value = Py_None; + if (!PyArg_UnpackTuple(args, "get", 1, 2, &key, &default_value)) { + return NULL; + } + + const void* item; + if (!_GetItemByKey(self, key, &item)) { + return NULL; + } + if (item == NULL) { + Py_INCREF(default_value); + return default_value; + } + return self->container_def->new_object_from_item_fn(item); +} + +static PyObject* Keys(PyContainer* self, PyObject* args) { + Py_ssize_t count = Length(self); + ScopedPyObjectPtr list(PyList_New(count)); + if (list == NULL) { + return NULL; + } + for (Py_ssize_t index = 0; index < count; ++index) { + PyObject* key = _NewKey_ByIndex(self, index); + if (key == NULL) { + return NULL; + } + PyList_SET_ITEM(list.get(), index, key); + } + return list.release(); +} + +static PyObject* Values(PyContainer* self, PyObject* args) { + Py_ssize_t count = Length(self); + ScopedPyObjectPtr list(PyList_New(count)); + if (list == NULL) { + return NULL; + } + for (Py_ssize_t index = 0; index < count; ++index) { + PyObject* value = _NewObj_ByIndex(self, index); + if (value == NULL) { + return NULL; + } + PyList_SET_ITEM(list.get(), index, value); + } + return list.release(); +} + +static PyObject* Items(PyContainer* self, PyObject* args) { + Py_ssize_t count = Length(self); + ScopedPyObjectPtr list(PyList_New(count)); + if (list == NULL) { + return NULL; + } + for (Py_ssize_t index = 0; index < count; ++index) { + ScopedPyObjectPtr obj(PyTuple_New(2)); + if (obj == NULL) { + return NULL; + } + PyObject* key = _NewKey_ByIndex(self, index); + if (key == NULL) { + return NULL; + } + PyTuple_SET_ITEM(obj.get(), 0, key); + PyObject* value = _NewObj_ByIndex(self, index); + if (value == NULL) { + return NULL; + } + PyTuple_SET_ITEM(obj.get(), 1, value); + PyList_SET_ITEM(list.get(), index, obj.release()); + } + return list.release(); +} + +static PyObject* NewContainerIterator(PyContainer* mapping, + PyContainerIterator::IterKind kind); + +static PyObject* Iter(PyContainer* self) { + return NewContainerIterator(self, PyContainerIterator::KIND_ITERKEY); +} +static PyObject* IterKeys(PyContainer* self, PyObject* args) { + return NewContainerIterator(self, PyContainerIterator::KIND_ITERKEY); +} +static PyObject* IterValues(PyContainer* self, PyObject* args) { + return NewContainerIterator(self, PyContainerIterator::KIND_ITERVALUE); +} +static PyObject* IterItems(PyContainer* self, PyObject* args) { + return NewContainerIterator(self, PyContainerIterator::KIND_ITERITEM); +} + +static PyMethodDef MappingMethods[] = { + { "get", (PyCFunction)Get, METH_VARARGS, }, + { "keys", (PyCFunction)Keys, METH_NOARGS, }, + { "values", (PyCFunction)Values, METH_NOARGS, }, + { "items", (PyCFunction)Items, METH_NOARGS, }, + { "iterkeys", (PyCFunction)IterKeys, METH_NOARGS, }, + { "itervalues", (PyCFunction)IterValues, METH_NOARGS, }, + { "iteritems", (PyCFunction)IterItems, METH_NOARGS, }, + {NULL} +}; + +PyTypeObject DescriptorMapping_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "DescriptorMapping", // tp_name + sizeof(PyContainer), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + (reprfunc)ContainerRepr, // tp_repr + 0, // tp_as_number + &MappingSequenceMethods, // tp_as_sequence + &MappingMappingMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + 0, // tp_doc + 0, // tp_traverse + 0, // tp_clear + (richcmpfunc)RichCompare, // tp_richcompare + 0, // tp_weaklistoffset + (getiterfunc)Iter, // tp_iter + 0, // tp_iternext + MappingMethods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + 0, // tp_new + 0, // tp_free +}; + +// The DescriptorSequence type. + +static PyObject* GetItem(PyContainer* self, Py_ssize_t index) { + if (index < 0) { + index += Length(self); + } + if (index < 0 || index >= Length(self)) { + PyErr_SetString(PyExc_IndexError, "index out of range"); + return NULL; + } + return _NewObj_ByIndex(self, index); +} + +static PyObject * +SeqSubscript(PyContainer* self, PyObject* item) { + if (PyIndex_Check(item)) { + Py_ssize_t index; + index = PyNumber_AsSsize_t(item, PyExc_IndexError); + if (index == -1 && PyErr_Occurred()) + return NULL; + return GetItem(self, index); + } + // Materialize the list and delegate the operation to it. + ScopedPyObjectPtr list(PyObject_CallFunctionObjArgs( + reinterpret_cast<PyObject*>(&PyList_Type), self, NULL)); + if (list == NULL) { + return NULL; + } + return Py_TYPE(list.get())->tp_as_mapping->mp_subscript(list.get(), item); +} + +// Returns the position of the item in the sequence, of -1 if not found. +// This function never fails. +int Find(PyContainer* self, PyObject* item) { + // The item can only be in one position: item.index. + // Check that self[item.index] == item, it's faster than a linear search. + // + // This assumes that sequences are only defined by syntax of the .proto file: + // a specific item belongs to only one sequence, depending on its position in + // the .proto file definition. + const void* descriptor_ptr = PyDescriptor_AsVoidPtr(item); + if (descriptor_ptr == NULL) { + PyErr_Clear(); + // Not a descriptor, it cannot be in the list. + return -1; + } + if (self->container_def->get_item_index_fn) { + int index = self->container_def->get_item_index_fn(descriptor_ptr); + if (index < 0 || index >= Length(self)) { + // This index is not from this collection. + return -1; + } + if (self->container_def->get_by_index_fn(self, index) != descriptor_ptr) { + // The descriptor at this index is not the same. + return -1; + } + // self[item.index] == item, so return the index. + return index; + } else { + // Fall back to linear search. + int length = Length(self); + for (int index=0; index < length; index++) { + if (self->container_def->get_by_index_fn(self, index) == descriptor_ptr) { + return index; + } + } + // Not found + return -1; + } +} + +// Implements list.index(): the position of the item is in the sequence. +static PyObject* Index(PyContainer* self, PyObject* item) { + int position = Find(self, item); + if (position < 0) { + // Not found + PyErr_SetNone(PyExc_ValueError); + return NULL; + } else { + return PyInt_FromLong(position); + } +} +// Implements "list.__contains__()": is the object in the sequence. +static int SeqContains(PyContainer* self, PyObject* item) { + int position = Find(self, item); + if (position < 0) { + return 0; + } else { + return 1; + } +} + +// Implements list.count(): number of occurrences of the item in the sequence. +// An item can only appear once in a sequence. If it exists, return 1. +static PyObject* Count(PyContainer* self, PyObject* item) { + int position = Find(self, item); + if (position < 0) { + return PyInt_FromLong(0); + } else { + return PyInt_FromLong(1); + } +} + +static PyObject* Append(PyContainer* self, PyObject* args) { + if (_CalledFromGeneratedFile(0)) { + Py_RETURN_NONE; + } + PyErr_Format(PyExc_TypeError, + "'%.200s' object is not a mutable sequence", + Py_TYPE(self)->tp_name); + return NULL; +} + +static PyObject* Reversed(PyContainer* self, PyObject* args) { + return NewContainerIterator(self, + PyContainerIterator::KIND_ITERVALUE_REVERSED); +} + +static PyMethodDef SeqMethods[] = { + { "index", (PyCFunction)Index, METH_O, }, + { "count", (PyCFunction)Count, METH_O, }, + { "append", (PyCFunction)Append, METH_O, }, + { "__reversed__", (PyCFunction)Reversed, METH_NOARGS, }, + {NULL} +}; + +static PySequenceMethods SeqSequenceMethods = { + (lenfunc)Length, // sq_length + 0, // sq_concat + 0, // sq_repeat + (ssizeargfunc)GetItem, // sq_item + 0, // sq_slice + 0, // sq_ass_item + 0, // sq_ass_slice + (objobjproc)SeqContains, // sq_contains +}; + +static PyMappingMethods SeqMappingMethods = { + (lenfunc)Length, // mp_length + (binaryfunc)SeqSubscript, // mp_subscript + 0, // mp_ass_subscript +}; + +PyTypeObject DescriptorSequence_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "DescriptorSequence", // tp_name + sizeof(PyContainer), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + (reprfunc)ContainerRepr, // tp_repr + 0, // tp_as_number + &SeqSequenceMethods, // tp_as_sequence + &SeqMappingMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + 0, // tp_doc + 0, // tp_traverse + 0, // tp_clear + (richcmpfunc)RichCompare, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + SeqMethods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + 0, // tp_new + 0, // tp_free +}; + +static PyObject* NewMappingByName( + DescriptorContainerDef* container_def, const void* descriptor) { + PyContainer* self = PyObject_New(PyContainer, &DescriptorMapping_Type); + if (self == NULL) { + return NULL; + } + self->descriptor = descriptor; + self->container_def = container_def; + self->kind = PyContainer::KIND_BYNAME; + return reinterpret_cast<PyObject*>(self); +} + +static PyObject* NewMappingByCamelcaseName( + DescriptorContainerDef* container_def, const void* descriptor) { + PyContainer* self = PyObject_New(PyContainer, &DescriptorMapping_Type); + if (self == NULL) { + return NULL; + } + self->descriptor = descriptor; + self->container_def = container_def; + self->kind = PyContainer::KIND_BYCAMELCASENAME; + return reinterpret_cast<PyObject*>(self); +} + +static PyObject* NewMappingByNumber( + DescriptorContainerDef* container_def, const void* descriptor) { + if (container_def->get_by_number_fn == NULL || + container_def->get_item_number_fn == NULL) { + PyErr_SetNone(PyExc_NotImplementedError); + return NULL; + } + PyContainer* self = PyObject_New(PyContainer, &DescriptorMapping_Type); + if (self == NULL) { + return NULL; + } + self->descriptor = descriptor; + self->container_def = container_def; + self->kind = PyContainer::KIND_BYNUMBER; + return reinterpret_cast<PyObject*>(self); +} + +static PyObject* NewSequence( + DescriptorContainerDef* container_def, const void* descriptor) { + PyContainer* self = PyObject_New(PyContainer, &DescriptorSequence_Type); + if (self == NULL) { + return NULL; + } + self->descriptor = descriptor; + self->container_def = container_def; + self->kind = PyContainer::KIND_SEQUENCE; + return reinterpret_cast<PyObject*>(self); +} + +// Implement iterators over PyContainers. + +static void Iterator_Dealloc(PyContainerIterator* self) { + Py_CLEAR(self->container); + Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); +} + +static PyObject* Iterator_Next(PyContainerIterator* self) { + int count = self->container->container_def->count_fn(self->container); + if (self->index >= count) { + // Return NULL with no exception to indicate the end. + return NULL; + } + int index = self->index; + self->index += 1; + switch (self->kind) { + case PyContainerIterator::KIND_ITERKEY: + return _NewKey_ByIndex(self->container, index); + case PyContainerIterator::KIND_ITERVALUE: + return _NewObj_ByIndex(self->container, index); + case PyContainerIterator::KIND_ITERVALUE_REVERSED: + return _NewObj_ByIndex(self->container, count - index - 1); + case PyContainerIterator::KIND_ITERITEM: + { + PyObject* obj = PyTuple_New(2); + if (obj == NULL) { + return NULL; + } + PyObject* key = _NewKey_ByIndex(self->container, index); + if (key == NULL) { + Py_DECREF(obj); + return NULL; + } + PyTuple_SET_ITEM(obj, 0, key); + PyObject* value = _NewObj_ByIndex(self->container, index); + if (value == NULL) { + Py_DECREF(obj); + return NULL; + } + PyTuple_SET_ITEM(obj, 1, value); + return obj; + } + default: + PyErr_SetNone(PyExc_NotImplementedError); + return NULL; + } +} + +static PyTypeObject ContainerIterator_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "DescriptorContainerIterator", // tp_name + sizeof(PyContainerIterator), // tp_basicsize + 0, // tp_itemsize + (destructor)Iterator_Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + 0, // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + PyObject_SelfIter, // tp_iter + (iternextfunc)Iterator_Next, // tp_iternext + 0, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + 0, // tp_new + 0, // tp_free +}; + +static PyObject* NewContainerIterator(PyContainer* container, + PyContainerIterator::IterKind kind) { + PyContainerIterator* self = PyObject_New(PyContainerIterator, + &ContainerIterator_Type); + if (self == NULL) { + return NULL; + } + Py_INCREF(container); + self->container = container; + self->kind = kind; + self->index = 0; + + return reinterpret_cast<PyObject*>(self); +} + +} // namespace descriptor + +// Now define the real collections! + +namespace message_descriptor { + +typedef const Descriptor* ParentDescriptor; + +static ParentDescriptor GetDescriptor(PyContainer* self) { + return reinterpret_cast<ParentDescriptor>(self->descriptor); +} + +namespace fields { + +typedef const FieldDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->field_count(); +} + +static const void* GetByName(PyContainer* self, ConstStringParam name) { + return GetDescriptor(self)->FindFieldByName(name); +} + +static const void* GetByCamelcaseName(PyContainer* self, + ConstStringParam name) { + return GetDescriptor(self)->FindFieldByCamelcaseName(name); +} + +static const void* GetByNumber(PyContainer* self, int number) { + return GetDescriptor(self)->FindFieldByNumber(number); +} + +static const void* GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->field(index); +} + +static PyObject* NewObjectFromItem(const void* item) { + return PyFieldDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item)); +} + +static const TProtoStringType& GetItemName(const void* item) { + return static_cast<ItemDescriptor>(item)->name(); +} + +static const TProtoStringType& GetItemCamelcaseName(const void* item) { + return static_cast<ItemDescriptor>(item)->camelcase_name(); +} + +static int GetItemNumber(const void* item) { + return static_cast<ItemDescriptor>(item)->number(); +} + +static int GetItemIndex(const void* item) { + return static_cast<ItemDescriptor>(item)->index(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageFields", + Count, + GetByIndex, + GetByName, + GetByCamelcaseName, + GetByNumber, + NewObjectFromItem, + GetItemName, + GetItemCamelcaseName, + GetItemNumber, + GetItemIndex, +}; + +} // namespace fields + +PyObject* NewMessageFieldsByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&fields::ContainerDef, descriptor); +} + +PyObject* NewMessageFieldsByCamelcaseName(ParentDescriptor descriptor) { + return descriptor::NewMappingByCamelcaseName(&fields::ContainerDef, + descriptor); +} + +PyObject* NewMessageFieldsByNumber(ParentDescriptor descriptor) { + return descriptor::NewMappingByNumber(&fields::ContainerDef, descriptor); +} + +PyObject* NewMessageFieldsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&fields::ContainerDef, descriptor); +} + +namespace nested_types { + +typedef const Descriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->nested_type_count(); +} + +static const void* GetByName(PyContainer* self, ConstStringParam name) { + return GetDescriptor(self)->FindNestedTypeByName(name); +} + +static const void* GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->nested_type(index); +} + +static PyObject* NewObjectFromItem(const void* item) { + return PyMessageDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item)); +} + +static const TProtoStringType& GetItemName(const void* item) { + return static_cast<ItemDescriptor>(item)->name(); +} + +static int GetItemIndex(const void* item) { + return static_cast<ItemDescriptor>(item)->index(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageNestedTypes", + Count, + GetByIndex, + GetByName, + NULL, + NULL, + NewObjectFromItem, + GetItemName, + NULL, + NULL, + GetItemIndex, +}; + +} // namespace nested_types + +PyObject* NewMessageNestedTypesSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&nested_types::ContainerDef, descriptor); +} + +PyObject* NewMessageNestedTypesByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&nested_types::ContainerDef, descriptor); +} + +namespace enums { + +typedef const EnumDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->enum_type_count(); +} + +static const void* GetByName(PyContainer* self, ConstStringParam name) { + return GetDescriptor(self)->FindEnumTypeByName(name); +} + +static const void* GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->enum_type(index); +} + +static PyObject* NewObjectFromItem(const void* item) { + return PyEnumDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item)); +} + +static const TProtoStringType& GetItemName(const void* item) { + return static_cast<ItemDescriptor>(item)->name(); +} + +static int GetItemIndex(const void* item) { + return static_cast<ItemDescriptor>(item)->index(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageNestedEnums", + Count, + GetByIndex, + GetByName, + NULL, + NULL, + NewObjectFromItem, + GetItemName, + NULL, + NULL, + GetItemIndex, +}; + +} // namespace enums + +PyObject* NewMessageEnumsByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&enums::ContainerDef, descriptor); +} + +PyObject* NewMessageEnumsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&enums::ContainerDef, descriptor); +} + +namespace enumvalues { + +// This is the "enum_values_by_name" mapping, which collects values from all +// enum types in a message. +// +// Note that the behavior of the C++ descriptor is different: it will search and +// return the first value that matches the name, whereas the Python +// implementation retrieves the last one. + +typedef const EnumValueDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + int count = 0; + for (int i = 0; i < GetDescriptor(self)->enum_type_count(); ++i) { + count += GetDescriptor(self)->enum_type(i)->value_count(); + } + return count; +} + +static const void* GetByName(PyContainer* self, ConstStringParam name) { + return GetDescriptor(self)->FindEnumValueByName(name); +} + +static const void* GetByIndex(PyContainer* self, int index) { + // This is not optimal, but the number of enums *types* in a given message + // is small. This function is only used when iterating over the mapping. + const EnumDescriptor* enum_type = NULL; + int enum_type_count = GetDescriptor(self)->enum_type_count(); + for (int i = 0; i < enum_type_count; ++i) { + enum_type = GetDescriptor(self)->enum_type(i); + int enum_value_count = enum_type->value_count(); + if (index < enum_value_count) { + // Found it! + break; + } + index -= enum_value_count; + } + // The next statement cannot overflow, because this function is only called by + // internal iterators which ensure that 0 <= index < Count(). + return enum_type->value(index); +} + +static PyObject* NewObjectFromItem(const void* item) { + return PyEnumValueDescriptor_FromDescriptor( + static_cast<ItemDescriptor>(item)); +} + +static const TProtoStringType& GetItemName(const void* item) { + return static_cast<ItemDescriptor>(item)->name(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageEnumValues", + Count, + GetByIndex, + GetByName, + NULL, + NULL, + NewObjectFromItem, + GetItemName, + NULL, + NULL, + NULL, +}; + +} // namespace enumvalues + +PyObject* NewMessageEnumValuesByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&enumvalues::ContainerDef, descriptor); +} + +namespace extensions { + +typedef const FieldDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->extension_count(); +} + +static const void* GetByName(PyContainer* self, ConstStringParam name) { + return GetDescriptor(self)->FindExtensionByName(name); +} + +static const void* GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->extension(index); +} + +static PyObject* NewObjectFromItem(const void* item) { + return PyFieldDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item)); +} + +static const TProtoStringType& GetItemName(const void* item) { + return static_cast<ItemDescriptor>(item)->name(); +} + +static int GetItemIndex(const void* item) { + return static_cast<ItemDescriptor>(item)->index(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageExtensions", + Count, + GetByIndex, + GetByName, + NULL, + NULL, + NewObjectFromItem, + GetItemName, + NULL, + NULL, + GetItemIndex, +}; + +} // namespace extensions + +PyObject* NewMessageExtensionsByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&extensions::ContainerDef, descriptor); +} + +PyObject* NewMessageExtensionsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&extensions::ContainerDef, descriptor); +} + +namespace oneofs { + +typedef const OneofDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->oneof_decl_count(); +} + +static const void* GetByName(PyContainer* self, ConstStringParam name) { + return GetDescriptor(self)->FindOneofByName(name); +} + +static const void* GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->oneof_decl(index); +} + +static PyObject* NewObjectFromItem(const void* item) { + return PyOneofDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item)); +} + +static const TProtoStringType& GetItemName(const void* item) { + return static_cast<ItemDescriptor>(item)->name(); +} + +static int GetItemIndex(const void* item) { + return static_cast<ItemDescriptor>(item)->index(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageOneofs", + Count, + GetByIndex, + GetByName, + NULL, + NULL, + NewObjectFromItem, + GetItemName, + NULL, + NULL, + GetItemIndex, +}; + +} // namespace oneofs + +PyObject* NewMessageOneofsByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&oneofs::ContainerDef, descriptor); +} + +PyObject* NewMessageOneofsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&oneofs::ContainerDef, descriptor); +} + +} // namespace message_descriptor + +namespace enum_descriptor { + +typedef const EnumDescriptor* ParentDescriptor; + +static ParentDescriptor GetDescriptor(PyContainer* self) { + return reinterpret_cast<ParentDescriptor>(self->descriptor); +} + +namespace enumvalues { + +typedef const EnumValueDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->value_count(); +} + +static const void* GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->value(index); +} + +static const void* GetByName(PyContainer* self, ConstStringParam name) { + return GetDescriptor(self)->FindValueByName(name); +} + +static const void* GetByNumber(PyContainer* self, int number) { + return GetDescriptor(self)->FindValueByNumber(number); +} + +static PyObject* NewObjectFromItem(const void* item) { + return PyEnumValueDescriptor_FromDescriptor( + static_cast<ItemDescriptor>(item)); +} + +static const TProtoStringType& GetItemName(const void* item) { + return static_cast<ItemDescriptor>(item)->name(); +} + +static int GetItemNumber(const void* item) { + return static_cast<ItemDescriptor>(item)->number(); +} + +static int GetItemIndex(const void* item) { + return static_cast<ItemDescriptor>(item)->index(); +} + +static DescriptorContainerDef ContainerDef = { + "EnumValues", + Count, + GetByIndex, + GetByName, + NULL, + GetByNumber, + NewObjectFromItem, + GetItemName, + NULL, + GetItemNumber, + GetItemIndex, +}; + +} // namespace enumvalues + +PyObject* NewEnumValuesByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&enumvalues::ContainerDef, descriptor); +} + +PyObject* NewEnumValuesByNumber(ParentDescriptor descriptor) { + return descriptor::NewMappingByNumber(&enumvalues::ContainerDef, descriptor); +} + +PyObject* NewEnumValuesSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&enumvalues::ContainerDef, descriptor); +} + +} // namespace enum_descriptor + +namespace oneof_descriptor { + +typedef const OneofDescriptor* ParentDescriptor; + +static ParentDescriptor GetDescriptor(PyContainer* self) { + return reinterpret_cast<ParentDescriptor>(self->descriptor); +} + +namespace fields { + +typedef const FieldDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->field_count(); +} + +static const void* GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->field(index); +} + +static PyObject* NewObjectFromItem(const void* item) { + return PyFieldDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item)); +} + +static int GetItemIndex(const void* item) { + return static_cast<ItemDescriptor>(item)->index_in_oneof(); +} + +static DescriptorContainerDef ContainerDef = { + "OneofFields", + Count, + GetByIndex, + NULL, + NULL, + NULL, + NewObjectFromItem, + NULL, + NULL, + NULL, + GetItemIndex, +}; + +} // namespace fields + +PyObject* NewOneofFieldsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&fields::ContainerDef, descriptor); +} + +} // namespace oneof_descriptor + +namespace service_descriptor { + +typedef const ServiceDescriptor* ParentDescriptor; + +static ParentDescriptor GetDescriptor(PyContainer* self) { + return reinterpret_cast<ParentDescriptor>(self->descriptor); +} + +namespace methods { + +typedef const MethodDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->method_count(); +} + +static const void* GetByName(PyContainer* self, ConstStringParam name) { + return GetDescriptor(self)->FindMethodByName(name); +} + +static const void* GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->method(index); +} + +static PyObject* NewObjectFromItem(const void* item) { + return PyMethodDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item)); +} + +static const TProtoStringType& GetItemName(const void* item) { + return static_cast<ItemDescriptor>(item)->name(); +} + +static int GetItemIndex(const void* item) { + return static_cast<ItemDescriptor>(item)->index(); +} + +static DescriptorContainerDef ContainerDef = { + "ServiceMethods", + Count, + GetByIndex, + GetByName, + NULL, + NULL, + NewObjectFromItem, + GetItemName, + NULL, + NULL, + GetItemIndex, +}; + +} // namespace methods + +PyObject* NewServiceMethodsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&methods::ContainerDef, descriptor); +} + +PyObject* NewServiceMethodsByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&methods::ContainerDef, descriptor); +} + +} // namespace service_descriptor + +namespace file_descriptor { + +typedef const FileDescriptor* ParentDescriptor; + +static ParentDescriptor GetDescriptor(PyContainer* self) { + return reinterpret_cast<ParentDescriptor>(self->descriptor); +} + +namespace messages { + +typedef const Descriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->message_type_count(); +} + +static const void* GetByName(PyContainer* self, ConstStringParam name) { + return GetDescriptor(self)->FindMessageTypeByName(name); +} + +static const void* GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->message_type(index); +} + +static PyObject* NewObjectFromItem(const void* item) { + return PyMessageDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item)); +} + +static const TProtoStringType& GetItemName(const void* item) { + return static_cast<ItemDescriptor>(item)->name(); +} + +static int GetItemIndex(const void* item) { + return static_cast<ItemDescriptor>(item)->index(); +} + +static DescriptorContainerDef ContainerDef = { + "FileMessages", + Count, + GetByIndex, + GetByName, + NULL, + NULL, + NewObjectFromItem, + GetItemName, + NULL, + NULL, + GetItemIndex, +}; + +} // namespace messages + +PyObject* NewFileMessageTypesByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&messages::ContainerDef, descriptor); +} + +namespace enums { + +typedef const EnumDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->enum_type_count(); +} + +static const void* GetByName(PyContainer* self, ConstStringParam name) { + return GetDescriptor(self)->FindEnumTypeByName(name); +} + +static const void* GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->enum_type(index); +} + +static PyObject* NewObjectFromItem(const void* item) { + return PyEnumDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item)); +} + +static const TProtoStringType& GetItemName(const void* item) { + return static_cast<ItemDescriptor>(item)->name(); +} + +static int GetItemIndex(const void* item) { + return static_cast<ItemDescriptor>(item)->index(); +} + +static DescriptorContainerDef ContainerDef = { + "FileEnums", + Count, + GetByIndex, + GetByName, + NULL, + NULL, + NewObjectFromItem, + GetItemName, + NULL, + NULL, + GetItemIndex, +}; + +} // namespace enums + +PyObject* NewFileEnumTypesByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&enums::ContainerDef, descriptor); +} + +namespace extensions { + +typedef const FieldDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->extension_count(); +} + +static const void* GetByName(PyContainer* self, ConstStringParam name) { + return GetDescriptor(self)->FindExtensionByName(name); +} + +static const void* GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->extension(index); +} + +static PyObject* NewObjectFromItem(const void* item) { + return PyFieldDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item)); +} + +static const TProtoStringType& GetItemName(const void* item) { + return static_cast<ItemDescriptor>(item)->name(); +} + +static int GetItemIndex(const void* item) { + return static_cast<ItemDescriptor>(item)->index(); +} + +static DescriptorContainerDef ContainerDef = { + "FileExtensions", + Count, + GetByIndex, + GetByName, + NULL, + NULL, + NewObjectFromItem, + GetItemName, + NULL, + NULL, + GetItemIndex, +}; + +} // namespace extensions + +PyObject* NewFileExtensionsByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&extensions::ContainerDef, descriptor); +} + +namespace services { + +typedef const ServiceDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->service_count(); +} + +static const void* GetByName(PyContainer* self, ConstStringParam name) { + return GetDescriptor(self)->FindServiceByName(name); +} + +static const void* GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->service(index); +} + +static PyObject* NewObjectFromItem(const void* item) { + return PyServiceDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item)); +} + +static const TProtoStringType& GetItemName(const void* item) { + return static_cast<ItemDescriptor>(item)->name(); +} + +static int GetItemIndex(const void* item) { + return static_cast<ItemDescriptor>(item)->index(); +} + +static DescriptorContainerDef ContainerDef = { + "FileServices", + Count, + GetByIndex, + GetByName, + NULL, + NULL, + NewObjectFromItem, + GetItemName, + NULL, + NULL, + GetItemIndex, +}; + +} // namespace services + +PyObject* NewFileServicesByName(const FileDescriptor* descriptor) { + return descriptor::NewMappingByName(&services::ContainerDef, descriptor); +} + +namespace dependencies { + +typedef const FileDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->dependency_count(); +} + +static const void* GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->dependency(index); +} + +static PyObject* NewObjectFromItem(const void* item) { + return PyFileDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item)); +} + +static DescriptorContainerDef ContainerDef = { + "FileDependencies", + Count, + GetByIndex, + NULL, + NULL, + NULL, + NewObjectFromItem, + NULL, + NULL, + NULL, + NULL, +}; + +} // namespace dependencies + +PyObject* NewFileDependencies(const FileDescriptor* descriptor) { + return descriptor::NewSequence(&dependencies::ContainerDef, descriptor); +} + +namespace public_dependencies { + +typedef const FileDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->public_dependency_count(); +} + +static const void* GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->public_dependency(index); +} + +static PyObject* NewObjectFromItem(const void* item) { + return PyFileDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item)); +} + +static DescriptorContainerDef ContainerDef = { + "FilePublicDependencies", + Count, + GetByIndex, + NULL, + NULL, + NULL, + NewObjectFromItem, + NULL, + NULL, + NULL, + NULL, +}; + +} // namespace public_dependencies + +PyObject* NewFilePublicDependencies(const FileDescriptor* descriptor) { + return descriptor::NewSequence(&public_dependencies::ContainerDef, + descriptor); +} + +} // namespace file_descriptor + + +// Register all implementations + +bool InitDescriptorMappingTypes() { + if (PyType_Ready(&descriptor::DescriptorMapping_Type) < 0) + return false; + if (PyType_Ready(&descriptor::DescriptorSequence_Type) < 0) + return false; + if (PyType_Ready(&descriptor::ContainerIterator_Type) < 0) + return false; + return true; +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_containers.h b/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_containers.h new file mode 100644 index 0000000000..4e05c58e2b --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_containers.h @@ -0,0 +1,109 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__ + +// Mappings and Sequences of descriptors. +// They implement containers like fields_by_name, EnumDescriptor.values... +// See descriptor_containers.cc for more description. +#include <Python.h> + +namespace google { +namespace protobuf { + +class Descriptor; +class FileDescriptor; +class EnumDescriptor; +class OneofDescriptor; +class ServiceDescriptor; + +namespace python { + +// Initialize the various types and objects. +bool InitDescriptorMappingTypes(); + +// Each function below returns a Mapping, or a Sequence of descriptors. +// They all return a new reference. + +namespace message_descriptor { +PyObject* NewMessageFieldsByName(const Descriptor* descriptor); +PyObject* NewMessageFieldsByCamelcaseName(const Descriptor* descriptor); +PyObject* NewMessageFieldsByNumber(const Descriptor* descriptor); +PyObject* NewMessageFieldsSeq(const Descriptor* descriptor); + +PyObject* NewMessageNestedTypesSeq(const Descriptor* descriptor); +PyObject* NewMessageNestedTypesByName(const Descriptor* descriptor); + +PyObject* NewMessageEnumsByName(const Descriptor* descriptor); +PyObject* NewMessageEnumsSeq(const Descriptor* descriptor); +PyObject* NewMessageEnumValuesByName(const Descriptor* descriptor); + +PyObject* NewMessageExtensionsByName(const Descriptor* descriptor); +PyObject* NewMessageExtensionsSeq(const Descriptor* descriptor); + +PyObject* NewMessageOneofsByName(const Descriptor* descriptor); +PyObject* NewMessageOneofsSeq(const Descriptor* descriptor); +} // namespace message_descriptor + +namespace enum_descriptor { +PyObject* NewEnumValuesByName(const EnumDescriptor* descriptor); +PyObject* NewEnumValuesByNumber(const EnumDescriptor* descriptor); +PyObject* NewEnumValuesSeq(const EnumDescriptor* descriptor); +} // namespace enum_descriptor + +namespace oneof_descriptor { +PyObject* NewOneofFieldsSeq(const OneofDescriptor* descriptor); +} // namespace oneof_descriptor + +namespace file_descriptor { +PyObject* NewFileMessageTypesByName(const FileDescriptor* descriptor); + +PyObject* NewFileEnumTypesByName(const FileDescriptor* descriptor); + +PyObject* NewFileExtensionsByName(const FileDescriptor* descriptor); + +PyObject* NewFileServicesByName(const FileDescriptor* descriptor); + +PyObject* NewFileDependencies(const FileDescriptor* descriptor); +PyObject* NewFilePublicDependencies(const FileDescriptor* descriptor); +} // namespace file_descriptor + +namespace service_descriptor { +PyObject* NewServiceMethodsSeq(const ServiceDescriptor* descriptor); +PyObject* NewServiceMethodsByName(const ServiceDescriptor* descriptor); +} // namespace service_descriptor + + +} // namespace python +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__ diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_database.cc b/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_database.cc new file mode 100644 index 0000000000..be8089841c --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_database.cc @@ -0,0 +1,187 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// This file defines a C++ DescriptorDatabase, which wraps a Python Database +// and delegate all its operations to Python methods. + +#include <google/protobuf/pyext/descriptor_database.h> + +#include <cstdint> + +#include <google/protobuf/stubs/logging.h> +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/descriptor.pb.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +namespace google { +namespace protobuf { +namespace python { + +PyDescriptorDatabase::PyDescriptorDatabase(PyObject* py_database) + : py_database_(py_database) { + Py_INCREF(py_database_); +} + +PyDescriptorDatabase::~PyDescriptorDatabase() { Py_DECREF(py_database_); } + +// Convert a Python object to a FileDescriptorProto pointer. +// Handles all kinds of Python errors, which are simply logged. +static bool GetFileDescriptorProto(PyObject* py_descriptor, + FileDescriptorProto* output) { + if (py_descriptor == NULL) { + if (PyErr_ExceptionMatches(PyExc_KeyError)) { + // Expected error: item was simply not found. + PyErr_Clear(); + } else { + GOOGLE_LOG(ERROR) << "DescriptorDatabase method raised an error"; + PyErr_Print(); + } + return false; + } + if (py_descriptor == Py_None) { + return false; + } + const Descriptor* filedescriptor_descriptor = + FileDescriptorProto::default_instance().GetDescriptor(); + CMessage* message = reinterpret_cast<CMessage*>(py_descriptor); + if (PyObject_TypeCheck(py_descriptor, CMessage_Type) && + message->message->GetDescriptor() == filedescriptor_descriptor) { + // Fast path: Just use the pointer. + FileDescriptorProto* file_proto = + static_cast<FileDescriptorProto*>(message->message); + *output = *file_proto; + return true; + } else { + // Slow path: serialize the message. This allows to use databases which + // use a different implementation of FileDescriptorProto. + ScopedPyObjectPtr serialized_pb( + PyObject_CallMethod(py_descriptor, "SerializeToString", NULL)); + if (serialized_pb == NULL) { + GOOGLE_LOG(ERROR) + << "DescriptorDatabase method did not return a FileDescriptorProto"; + PyErr_Print(); + return false; + } + char* str; + Py_ssize_t len; + if (PyBytes_AsStringAndSize(serialized_pb.get(), &str, &len) < 0) { + GOOGLE_LOG(ERROR) + << "DescriptorDatabase method did not return a FileDescriptorProto"; + PyErr_Print(); + return false; + } + FileDescriptorProto file_proto; + if (!file_proto.ParseFromArray(str, len)) { + GOOGLE_LOG(ERROR) + << "DescriptorDatabase method did not return a FileDescriptorProto"; + return false; + } + *output = file_proto; + return true; + } +} + +// Find a file by file name. +bool PyDescriptorDatabase::FindFileByName(const TProtoStringType& filename, + FileDescriptorProto* output) { + ScopedPyObjectPtr py_descriptor(PyObject_CallMethod( + py_database_, "FindFileByName", "s#", filename.c_str(), filename.size())); + return GetFileDescriptorProto(py_descriptor.get(), output); +} + +// Find the file that declares the given fully-qualified symbol name. +bool PyDescriptorDatabase::FindFileContainingSymbol( + const TProtoStringType& symbol_name, FileDescriptorProto* output) { + ScopedPyObjectPtr py_descriptor( + PyObject_CallMethod(py_database_, "FindFileContainingSymbol", "s#", + symbol_name.c_str(), symbol_name.size())); + return GetFileDescriptorProto(py_descriptor.get(), output); +} + +// Find the file which defines an extension extending the given message type +// with the given field number. +// Python DescriptorDatabases are not required to implement this method. +bool PyDescriptorDatabase::FindFileContainingExtension( + const TProtoStringType& containing_type, int field_number, + FileDescriptorProto* output) { + ScopedPyObjectPtr py_method( + PyObject_GetAttrString(py_database_, "FindFileContainingExtension")); + if (py_method == NULL) { + // This method is not implemented, returns without error. + PyErr_Clear(); + return false; + } + ScopedPyObjectPtr py_descriptor( + PyObject_CallFunction(py_method.get(), "s#i", containing_type.c_str(), + containing_type.size(), field_number)); + return GetFileDescriptorProto(py_descriptor.get(), output); +} + +// Finds the tag numbers used by all known extensions of +// containing_type, and appends them to output in an undefined +// order. +// Python DescriptorDatabases are not required to implement this method. +bool PyDescriptorDatabase::FindAllExtensionNumbers( + const TProtoStringType& containing_type, std::vector<int>* output) { + ScopedPyObjectPtr py_method( + PyObject_GetAttrString(py_database_, "FindAllExtensionNumbers")); + if (py_method == NULL) { + // This method is not implemented, returns without error. + PyErr_Clear(); + return false; + } + ScopedPyObjectPtr py_list( + PyObject_CallFunction(py_method.get(), "s#", containing_type.c_str(), + containing_type.size())); + if (py_list == NULL) { + PyErr_Print(); + return false; + } + Py_ssize_t size = PyList_Size(py_list.get()); + int64_t item_value; + for (Py_ssize_t i = 0 ; i < size; ++i) { + ScopedPyObjectPtr item(PySequence_GetItem(py_list.get(), i)); + item_value = PyLong_AsLong(item.get()); + if (item_value < 0) { + GOOGLE_LOG(ERROR) + << "FindAllExtensionNumbers method did not return " + << "valid extension numbers."; + PyErr_Print(); + return false; + } + output->push_back(item_value); + } + return true; +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_database.h b/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_database.h new file mode 100644 index 0000000000..59918a6d92 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_database.h @@ -0,0 +1,81 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__ + +#include <Python.h> + +#include <google/protobuf/descriptor_database.h> + +namespace google { +namespace protobuf { +namespace python { + +class PyDescriptorDatabase : public DescriptorDatabase { + public: + explicit PyDescriptorDatabase(PyObject* py_database); + ~PyDescriptorDatabase(); + + // Implement the abstract interface. All these functions fill the output + // with a copy of FileDescriptorProto. + + // Find a file by file name. + bool FindFileByName(const TProtoStringType& filename, FileDescriptorProto* output); + + // Find the file that declares the given fully-qualified symbol name. + bool FindFileContainingSymbol(const TProtoStringType& symbol_name, + FileDescriptorProto* output); + + // Find the file which defines an extension extending the given message type + // with the given field number. + // Containing_type must be a fully-qualified type name. + // Python objects are not required to implement this method. + bool FindFileContainingExtension(const TProtoStringType& containing_type, + int field_number, + FileDescriptorProto* output); + + // Finds the tag numbers used by all known extensions of + // containing_type, and appends them to output in an undefined + // order. + // Python objects are not required to implement this method. + bool FindAllExtensionNumbers(const TProtoStringType& containing_type, + std::vector<int>* output); + + private: + // The python object that implements the database. The reference is owned. + PyObject* py_database_; +}; + +} // namespace python +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__ diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_pool.cc b/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_pool.cc new file mode 100644 index 0000000000..a53411e797 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_pool.cc @@ -0,0 +1,773 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Implements the DescriptorPool, which collects all descriptors. + +#include <unordered_map> + +#include <Python.h> + +#include <google/protobuf/descriptor.pb.h> +#include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/descriptor_database.h> +#include <google/protobuf/pyext/descriptor_pool.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/message_factory.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> +#include <google/protobuf/stubs/hash.h> + +#if PY_MAJOR_VERSION >= 3 + #define PyString_FromStringAndSize PyUnicode_FromStringAndSize + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #endif +#define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \ + PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \ + ? -1 \ + : 0) \ + : PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif + +namespace google { +namespace protobuf { +namespace python { + +// A map to cache Python Pools per C++ pointer. +// Pointers are not owned here, and belong to the PyDescriptorPool. +static std::unordered_map<const DescriptorPool*, PyDescriptorPool*>* + descriptor_pool_map; + +namespace cdescriptor_pool { + +// Collects errors that occur during proto file building to allow them to be +// propagated in the python exception instead of only living in ERROR logs. +class BuildFileErrorCollector : public DescriptorPool::ErrorCollector { + public: + BuildFileErrorCollector() : error_message(""), had_errors_(false) {} + + void AddError(const TProtoStringType& filename, const TProtoStringType& element_name, + const Message* descriptor, ErrorLocation location, + const TProtoStringType& message) override { + // Replicates the logging behavior that happens in the C++ implementation + // when an error collector is not passed in. + if (!had_errors_) { + error_message += + ("Invalid proto descriptor for file \"" + filename + "\":\n"); + had_errors_ = true; + } + // As this only happens on failure and will result in the program not + // running at all, no effort is made to optimize this string manipulation. + error_message += (" " + element_name + ": " + message + "\n"); + } + + void Clear() { + had_errors_ = false; + error_message = ""; + } + + TProtoStringType error_message; + + private: + bool had_errors_; +}; + +// Create a Python DescriptorPool object, but does not fill the "pool" +// attribute. +static PyDescriptorPool* _CreateDescriptorPool() { + PyDescriptorPool* cpool = PyObject_GC_New( + PyDescriptorPool, &PyDescriptorPool_Type); + if (cpool == NULL) { + return NULL; + } + + cpool->error_collector = nullptr; + cpool->underlay = NULL; + cpool->database = NULL; + + cpool->descriptor_options = new std::unordered_map<const void*, PyObject*>(); + + cpool->py_message_factory = message_factory::NewMessageFactory( + &PyMessageFactory_Type, cpool); + if (cpool->py_message_factory == NULL) { + Py_DECREF(cpool); + return NULL; + } + + PyObject_GC_Track(cpool); + + return cpool; +} + +// Create a Python DescriptorPool, using the given pool as an underlay: +// new messages will be added to a custom pool, not to the underlay. +// +// Ownership of the underlay is not transferred, its pointer should +// stay alive. +static PyDescriptorPool* PyDescriptorPool_NewWithUnderlay( + const DescriptorPool* underlay) { + PyDescriptorPool* cpool = _CreateDescriptorPool(); + if (cpool == NULL) { + return NULL; + } + cpool->pool = new DescriptorPool(underlay); + cpool->underlay = underlay; + + if (!descriptor_pool_map->insert( + std::make_pair(cpool->pool, cpool)).second) { + // Should never happen -- would indicate an internal error / bug. + PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered"); + return NULL; + } + + return cpool; +} + +static PyDescriptorPool* PyDescriptorPool_NewWithDatabase( + DescriptorDatabase* database) { + PyDescriptorPool* cpool = _CreateDescriptorPool(); + if (cpool == NULL) { + return NULL; + } + if (database != NULL) { + cpool->error_collector = new BuildFileErrorCollector(); + cpool->pool = new DescriptorPool(database, cpool->error_collector); + cpool->database = database; + } else { + cpool->pool = new DescriptorPool(); + } + + if (!descriptor_pool_map->insert(std::make_pair(cpool->pool, cpool)).second) { + // Should never happen -- would indicate an internal error / bug. + PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered"); + return NULL; + } + + return cpool; +} + +// The public DescriptorPool constructor. +static PyObject* New(PyTypeObject* type, + PyObject* args, PyObject* kwargs) { + static const char* kwlist[] = {"descriptor_db", 0}; + PyObject* py_database = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", + const_cast<char**>(kwlist), &py_database)) { + return NULL; + } + DescriptorDatabase* database = NULL; + if (py_database && py_database != Py_None) { + database = new PyDescriptorDatabase(py_database); + } + return reinterpret_cast<PyObject*>( + PyDescriptorPool_NewWithDatabase(database)); +} + +static void Dealloc(PyObject* pself) { + PyDescriptorPool* self = reinterpret_cast<PyDescriptorPool*>(pself); + descriptor_pool_map->erase(self->pool); + Py_CLEAR(self->py_message_factory); + for (std::unordered_map<const void*, PyObject*>::iterator it = + self->descriptor_options->begin(); + it != self->descriptor_options->end(); ++it) { + Py_DECREF(it->second); + } + delete self->descriptor_options; + delete self->database; + delete self->pool; + delete self->error_collector; + Py_TYPE(self)->tp_free(pself); +} + +static int GcTraverse(PyObject* pself, visitproc visit, void* arg) { + PyDescriptorPool* self = reinterpret_cast<PyDescriptorPool*>(pself); + Py_VISIT(self->py_message_factory); + return 0; +} + +static int GcClear(PyObject* pself) { + PyDescriptorPool* self = reinterpret_cast<PyDescriptorPool*>(pself); + Py_CLEAR(self->py_message_factory); + return 0; +} + +PyObject* SetErrorFromCollector(DescriptorPool::ErrorCollector* self, + const char* name, const char* error_type) { + BuildFileErrorCollector* error_collector = + reinterpret_cast<BuildFileErrorCollector*>(self); + if (error_collector && !error_collector->error_message.empty()) { + PyErr_Format(PyExc_KeyError, "Couldn't build file for %s %.200s\n%s", + error_type, name, error_collector->error_message.c_str()); + error_collector->Clear(); + return NULL; + } + PyErr_Format(PyExc_KeyError, "Couldn't find %s %.200s", error_type, name); + return NULL; +} + +static PyObject* FindMessageByName(PyObject* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const Descriptor* message_descriptor = + reinterpret_cast<PyDescriptorPool*>(self)->pool->FindMessageTypeByName( + StringParam(name, name_size)); + + if (message_descriptor == NULL) { + return SetErrorFromCollector( + reinterpret_cast<PyDescriptorPool*>(self)->error_collector, name, + "message"); + } + + + return PyMessageDescriptor_FromDescriptor(message_descriptor); +} + + + + +static PyObject* FindFileByName(PyObject* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + PyDescriptorPool* py_pool = reinterpret_cast<PyDescriptorPool*>(self); + const FileDescriptor* file_descriptor = + py_pool->pool->FindFileByName(StringParam(name, name_size)); + + if (file_descriptor == NULL) { + return SetErrorFromCollector(py_pool->error_collector, name, "file"); + } + return PyFileDescriptor_FromDescriptor(file_descriptor); +} + +PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const FieldDescriptor* field_descriptor = + self->pool->FindFieldByName(StringParam(name, name_size)); + if (field_descriptor == NULL) { + return SetErrorFromCollector(self->error_collector, name, "field"); + } + + + return PyFieldDescriptor_FromDescriptor(field_descriptor); +} + +static PyObject* FindFieldByNameMethod(PyObject* self, PyObject* arg) { + return FindFieldByName(reinterpret_cast<PyDescriptorPool*>(self), arg); +} + +PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const FieldDescriptor* field_descriptor = + self->pool->FindExtensionByName(StringParam(name, name_size)); + if (field_descriptor == NULL) { + return SetErrorFromCollector(self->error_collector, name, + "extension field"); + } + + + return PyFieldDescriptor_FromDescriptor(field_descriptor); +} + +static PyObject* FindExtensionByNameMethod(PyObject* self, PyObject* arg) { + return FindExtensionByName(reinterpret_cast<PyDescriptorPool*>(self), arg); +} + +PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const EnumDescriptor* enum_descriptor = + self->pool->FindEnumTypeByName(StringParam(name, name_size)); + if (enum_descriptor == NULL) { + return SetErrorFromCollector(self->error_collector, name, "enum"); + } + + + return PyEnumDescriptor_FromDescriptor(enum_descriptor); +} + +static PyObject* FindEnumTypeByNameMethod(PyObject* self, PyObject* arg) { + return FindEnumTypeByName(reinterpret_cast<PyDescriptorPool*>(self), arg); +} + +PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const OneofDescriptor* oneof_descriptor = + self->pool->FindOneofByName(StringParam(name, name_size)); + if (oneof_descriptor == NULL) { + return SetErrorFromCollector(self->error_collector, name, "oneof"); + } + + + return PyOneofDescriptor_FromDescriptor(oneof_descriptor); +} + +static PyObject* FindOneofByNameMethod(PyObject* self, PyObject* arg) { + return FindOneofByName(reinterpret_cast<PyDescriptorPool*>(self), arg); +} + +static PyObject* FindServiceByName(PyObject* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const ServiceDescriptor* service_descriptor = + reinterpret_cast<PyDescriptorPool*>(self)->pool->FindServiceByName( + StringParam(name, name_size)); + if (service_descriptor == NULL) { + return SetErrorFromCollector( + reinterpret_cast<PyDescriptorPool*>(self)->error_collector, name, + "service"); + } + + + return PyServiceDescriptor_FromDescriptor(service_descriptor); +} + +static PyObject* FindMethodByName(PyObject* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const MethodDescriptor* method_descriptor = + reinterpret_cast<PyDescriptorPool*>(self)->pool->FindMethodByName( + StringParam(name, name_size)); + if (method_descriptor == NULL) { + return SetErrorFromCollector( + reinterpret_cast<PyDescriptorPool*>(self)->error_collector, name, + "method"); + } + + + return PyMethodDescriptor_FromDescriptor(method_descriptor); +} + +static PyObject* FindFileContainingSymbol(PyObject* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const FileDescriptor* file_descriptor = + reinterpret_cast<PyDescriptorPool*>(self)->pool->FindFileContainingSymbol( + StringParam(name, name_size)); + if (file_descriptor == NULL) { + return SetErrorFromCollector( + reinterpret_cast<PyDescriptorPool*>(self)->error_collector, name, + "symbol"); + } + + + return PyFileDescriptor_FromDescriptor(file_descriptor); +} + +static PyObject* FindExtensionByNumber(PyObject* self, PyObject* args) { + PyObject* message_descriptor; + int number; + if (!PyArg_ParseTuple(args, "Oi", &message_descriptor, &number)) { + return NULL; + } + const Descriptor* descriptor = PyMessageDescriptor_AsDescriptor( + message_descriptor); + if (descriptor == NULL) { + return NULL; + } + + const FieldDescriptor* extension_descriptor = + reinterpret_cast<PyDescriptorPool*>(self)->pool->FindExtensionByNumber( + descriptor, number); + if (extension_descriptor == NULL) { + BuildFileErrorCollector* error_collector = + reinterpret_cast<BuildFileErrorCollector*>( + reinterpret_cast<PyDescriptorPool*>(self)->error_collector); + if (error_collector && !error_collector->error_message.empty()) { + PyErr_Format(PyExc_KeyError, "Couldn't build file for Extension %.d\n%s", + number, error_collector->error_message.c_str()); + error_collector->Clear(); + return NULL; + } + PyErr_Format(PyExc_KeyError, "Couldn't find Extension %d", number); + return NULL; + } + + + return PyFieldDescriptor_FromDescriptor(extension_descriptor); +} + +static PyObject* FindAllExtensions(PyObject* self, PyObject* arg) { + const Descriptor* descriptor = PyMessageDescriptor_AsDescriptor(arg); + if (descriptor == NULL) { + return NULL; + } + + std::vector<const FieldDescriptor*> extensions; + reinterpret_cast<PyDescriptorPool*>(self)->pool->FindAllExtensions( + descriptor, &extensions); + + ScopedPyObjectPtr result(PyList_New(extensions.size())); + if (result == NULL) { + return NULL; + } + for (int i = 0; i < extensions.size(); i++) { + PyObject* extension = PyFieldDescriptor_FromDescriptor(extensions[i]); + if (extension == NULL) { + return NULL; + } + PyList_SET_ITEM(result.get(), i, extension); // Steals the reference. + } + return result.release(); +} + +// These functions should not exist -- the only valid way to create +// descriptors is to call Add() or AddSerializedFile(). +// But these AddDescriptor() functions were created in Python and some people +// call them, so we support them for now for compatibility. +// However we do check that the existing descriptor already exists in the pool, +// which appears to always be true for existing calls -- but then why do people +// call a function that will just be a no-op? +// TODO(amauryfa): Need to investigate further. + +static PyObject* AddFileDescriptor(PyObject* self, PyObject* descriptor) { + const FileDescriptor* file_descriptor = + PyFileDescriptor_AsDescriptor(descriptor); + if (!file_descriptor) { + return NULL; + } + if (file_descriptor != + reinterpret_cast<PyDescriptorPool*>(self)->pool->FindFileByName( + file_descriptor->name())) { + PyErr_Format(PyExc_ValueError, + "The file descriptor %s does not belong to this pool", + file_descriptor->name().c_str()); + return NULL; + } + Py_RETURN_NONE; +} + +static PyObject* AddDescriptor(PyObject* self, PyObject* descriptor) { + const Descriptor* message_descriptor = + PyMessageDescriptor_AsDescriptor(descriptor); + if (!message_descriptor) { + return NULL; + } + if (message_descriptor != + reinterpret_cast<PyDescriptorPool*>(self)->pool->FindMessageTypeByName( + message_descriptor->full_name())) { + PyErr_Format(PyExc_ValueError, + "The message descriptor %s does not belong to this pool", + message_descriptor->full_name().c_str()); + return NULL; + } + Py_RETURN_NONE; +} + +static PyObject* AddEnumDescriptor(PyObject* self, PyObject* descriptor) { + const EnumDescriptor* enum_descriptor = + PyEnumDescriptor_AsDescriptor(descriptor); + if (!enum_descriptor) { + return NULL; + } + if (enum_descriptor != + reinterpret_cast<PyDescriptorPool*>(self)->pool->FindEnumTypeByName( + enum_descriptor->full_name())) { + PyErr_Format(PyExc_ValueError, + "The enum descriptor %s does not belong to this pool", + enum_descriptor->full_name().c_str()); + return NULL; + } + Py_RETURN_NONE; +} + +static PyObject* AddExtensionDescriptor(PyObject* self, PyObject* descriptor) { + const FieldDescriptor* extension_descriptor = + PyFieldDescriptor_AsDescriptor(descriptor); + if (!extension_descriptor) { + return NULL; + } + if (extension_descriptor != + reinterpret_cast<PyDescriptorPool*>(self)->pool->FindExtensionByName( + extension_descriptor->full_name())) { + PyErr_Format(PyExc_ValueError, + "The extension descriptor %s does not belong to this pool", + extension_descriptor->full_name().c_str()); + return NULL; + } + Py_RETURN_NONE; +} + +static PyObject* AddServiceDescriptor(PyObject* self, PyObject* descriptor) { + const ServiceDescriptor* service_descriptor = + PyServiceDescriptor_AsDescriptor(descriptor); + if (!service_descriptor) { + return NULL; + } + if (service_descriptor != + reinterpret_cast<PyDescriptorPool*>(self)->pool->FindServiceByName( + service_descriptor->full_name())) { + PyErr_Format(PyExc_ValueError, + "The service descriptor %s does not belong to this pool", + service_descriptor->full_name().c_str()); + return NULL; + } + Py_RETURN_NONE; +} + +// The code below loads new Descriptors from a serialized FileDescriptorProto. +static PyObject* AddSerializedFile(PyObject* pself, PyObject* serialized_pb) { + PyDescriptorPool* self = reinterpret_cast<PyDescriptorPool*>(pself); + char* message_type; + Py_ssize_t message_len; + + if (self->database != NULL) { + PyErr_SetString( + PyExc_ValueError, + "Cannot call Add on a DescriptorPool that uses a DescriptorDatabase. " + "Add your file to the underlying database."); + return NULL; + } + + if (PyBytes_AsStringAndSize(serialized_pb, &message_type, &message_len) < 0) { + return NULL; + } + + FileDescriptorProto file_proto; + if (!file_proto.ParseFromArray(message_type, message_len)) { + PyErr_SetString(PyExc_TypeError, "Couldn't parse file content!"); + return NULL; + } + + // If the file was already part of a C++ library, all its descriptors are in + // the underlying pool. No need to do anything else. + const FileDescriptor* generated_file = NULL; + if (self->underlay) { + generated_file = self->underlay->FindFileByName(file_proto.name()); + } + if (generated_file != NULL) { + return PyFileDescriptor_FromDescriptorWithSerializedPb( + generated_file, serialized_pb); + } + + BuildFileErrorCollector error_collector; + const FileDescriptor* descriptor = + self->pool->BuildFileCollectingErrors(file_proto, + &error_collector); + if (descriptor == NULL) { + PyErr_Format(PyExc_TypeError, + "Couldn't build proto file into descriptor pool!\n%s", + error_collector.error_message.c_str()); + return NULL; + } + + return PyFileDescriptor_FromDescriptorWithSerializedPb( + descriptor, serialized_pb); +} + +static PyObject* Add(PyObject* self, PyObject* file_descriptor_proto) { + ScopedPyObjectPtr serialized_pb( + PyObject_CallMethod(file_descriptor_proto, "SerializeToString", NULL)); + if (serialized_pb == NULL) { + return NULL; + } + return AddSerializedFile(self, serialized_pb.get()); +} + +static PyMethodDef Methods[] = { + { "Add", Add, METH_O, + "Adds the FileDescriptorProto and its types to this pool." }, + { "AddSerializedFile", AddSerializedFile, METH_O, + "Adds a serialized FileDescriptorProto to this pool." }, + + // TODO(amauryfa): Understand why the Python implementation differs from + // this one, ask users to use another API and deprecate these functions. + { "AddFileDescriptor", AddFileDescriptor, METH_O, + "No-op. Add() must have been called before." }, + { "AddDescriptor", AddDescriptor, METH_O, + "No-op. Add() must have been called before." }, + { "AddEnumDescriptor", AddEnumDescriptor, METH_O, + "No-op. Add() must have been called before." }, + { "AddExtensionDescriptor", AddExtensionDescriptor, METH_O, + "No-op. Add() must have been called before." }, + { "AddServiceDescriptor", AddServiceDescriptor, METH_O, + "No-op. Add() must have been called before." }, + + { "FindFileByName", FindFileByName, METH_O, + "Searches for a file descriptor by its .proto name." }, + { "FindMessageTypeByName", FindMessageByName, METH_O, + "Searches for a message descriptor by full name." }, + { "FindFieldByName", FindFieldByNameMethod, METH_O, + "Searches for a field descriptor by full name." }, + { "FindExtensionByName", FindExtensionByNameMethod, METH_O, + "Searches for extension descriptor by full name." }, + { "FindEnumTypeByName", FindEnumTypeByNameMethod, METH_O, + "Searches for enum type descriptor by full name." }, + { "FindOneofByName", FindOneofByNameMethod, METH_O, + "Searches for oneof descriptor by full name." }, + { "FindServiceByName", FindServiceByName, METH_O, + "Searches for service descriptor by full name." }, + { "FindMethodByName", FindMethodByName, METH_O, + "Searches for method descriptor by full name." }, + + { "FindFileContainingSymbol", FindFileContainingSymbol, METH_O, + "Gets the FileDescriptor containing the specified symbol." }, + { "FindExtensionByNumber", FindExtensionByNumber, METH_VARARGS, + "Gets the extension descriptor for the given number." }, + { "FindAllExtensions", FindAllExtensions, METH_O, + "Gets all known extensions of the given message descriptor." }, + {NULL} +}; + +} // namespace cdescriptor_pool + +PyTypeObject PyDescriptorPool_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME + ".DescriptorPool", // tp_name + sizeof(PyDescriptorPool), // tp_basicsize + 0, // tp_itemsize + cdescriptor_pool::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, // tp_flags + "A Descriptor Pool", // tp_doc + cdescriptor_pool::GcTraverse, // tp_traverse + cdescriptor_pool::GcClear, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + cdescriptor_pool::Methods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + cdescriptor_pool::New, // tp_new + PyObject_GC_Del, // tp_free +}; + +// This is the DescriptorPool which contains all the definitions from the +// generated _pb2.py modules. +static PyDescriptorPool* python_generated_pool = NULL; + +bool InitDescriptorPool() { + if (PyType_Ready(&PyDescriptorPool_Type) < 0) + return false; + + // The Pool of messages declared in Python libraries. + // generated_pool() contains all messages already linked in C++ libraries, and + // is used as underlay. + descriptor_pool_map = + new std::unordered_map<const DescriptorPool*, PyDescriptorPool*>; + python_generated_pool = cdescriptor_pool::PyDescriptorPool_NewWithUnderlay( + DescriptorPool::generated_pool()); + if (python_generated_pool == NULL) { + delete descriptor_pool_map; + return false; + } + + // Register this pool to be found for C++-generated descriptors. + descriptor_pool_map->insert( + std::make_pair(DescriptorPool::generated_pool(), + python_generated_pool)); + + return true; +} + +// The default DescriptorPool used everywhere in this module. +// Today it's the python_generated_pool. +// TODO(amauryfa): Remove all usages of this function: the pool should be +// derived from the context. +PyDescriptorPool* GetDefaultDescriptorPool() { + return python_generated_pool; +} + +PyDescriptorPool* GetDescriptorPool_FromPool(const DescriptorPool* pool) { + // Fast path for standard descriptors. + if (pool == python_generated_pool->pool || + pool == DescriptorPool::generated_pool()) { + return python_generated_pool; + } + std::unordered_map<const DescriptorPool*, PyDescriptorPool*>::iterator it = + descriptor_pool_map->find(pool); + if (it == descriptor_pool_map->end()) { + PyErr_SetString(PyExc_KeyError, "Unknown descriptor pool"); + return NULL; + } + return it->second; +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_pool.h b/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_pool.h new file mode 100644 index 0000000000..2d456f9088 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/descriptor_pool.h @@ -0,0 +1,136 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__ + +#include <Python.h> + +#include <unordered_map> +#include <google/protobuf/descriptor.h> + +namespace google { +namespace protobuf { +namespace python { + +struct PyMessageFactory; + +// The (meta) type of all Messages classes. +struct CMessageClass; + +// Wraps operations to the global DescriptorPool which contains information +// about all messages and fields. +// +// There is normally one pool per process. We make it a Python object only +// because it contains many Python references. +// +// "Methods" that interacts with this DescriptorPool are in the cdescriptor_pool +// namespace. +typedef struct PyDescriptorPool { + PyObject_HEAD + + // The C++ pool containing Descriptors. + DescriptorPool* pool; + + // The error collector to store error info. Can be NULL. This pointer is + // owned. + DescriptorPool::ErrorCollector* error_collector; + + // The C++ pool acting as an underlay. Can be NULL. + // This pointer is not owned and must stay alive. + const DescriptorPool* underlay; + + // The C++ descriptor database used to fetch unknown protos. Can be NULL. + // This pointer is owned. + const DescriptorDatabase* database; + + // The preferred MessageFactory to be used by descriptors. + // TODO(amauryfa): Don't create the Factory from the DescriptorPool, but + // use the one passed while creating message classes. And remove this member. + PyMessageFactory* py_message_factory; + + // Cache the options for any kind of descriptor. + // Descriptor pointers are owned by the DescriptorPool above. + // Python objects are owned by the map. + std::unordered_map<const void*, PyObject*>* descriptor_options; +} PyDescriptorPool; + + +extern PyTypeObject PyDescriptorPool_Type; + +namespace cdescriptor_pool { + + +// The functions below are also exposed as methods of the DescriptorPool type. + +// Looks up a field by name. Returns a PyFieldDescriptor corresponding to +// the field on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* name); + +// Looks up an extension by name. Returns a PyFieldDescriptor corresponding +// to the field on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg); + +// Looks up an enum type by name. Returns a PyEnumDescriptor corresponding +// to the field on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg); + +// Looks up a oneof by name. Returns a COneofDescriptor corresponding +// to the oneof on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg); + +} // namespace cdescriptor_pool + +// Retrieve the global descriptor pool owned by the _message module. +// This is the one used by pb2.py generated modules. +// Returns a *borrowed* reference. +// "Default" pool used to register messages from _pb2.py modules. +PyDescriptorPool* GetDefaultDescriptorPool(); + +// Retrieve the python descriptor pool owning a C++ descriptor pool. +// Returns a *borrowed* reference. +PyDescriptorPool* GetDescriptorPool_FromPool(const DescriptorPool* pool); + +// Initialize objects used by this module. +bool InitDescriptorPool(); + +} // namespace python +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__ diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/extension_dict.cc b/contrib/python/protobuf/py3/google/protobuf/pyext/extension_dict.cc new file mode 100644 index 0000000000..37b414c375 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/extension_dict.cc @@ -0,0 +1,480 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#include <google/protobuf/pyext/extension_dict.h> + +#include <cstdint> +#include <memory> + +#include <google/protobuf/stubs/logging.h> +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/descriptor.pb.h> +#include <google/protobuf/descriptor.h> +#include <google/protobuf/dynamic_message.h> +#include <google/protobuf/message.h> +#include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/message_factory.h> +#include <google/protobuf/pyext/repeated_composite_container.h> +#include <google/protobuf/pyext/repeated_scalar_container.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +#if PY_MAJOR_VERSION >= 3 + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #endif +#define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \ + PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \ + ? -1 \ + : 0) \ + : PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif + +namespace google { +namespace protobuf { +namespace python { + +namespace extension_dict { + +static Py_ssize_t len(ExtensionDict* self) { + Py_ssize_t size = 0; + std::vector<const FieldDescriptor*> fields; + self->parent->message->GetReflection()->ListFields(*self->parent->message, + &fields); + + for (size_t i = 0; i < fields.size(); ++i) { + if (fields[i]->is_extension()) { + // With C++ descriptors, the field can always be retrieved, but for + // unknown extensions which have not been imported in Python code, there + // is no message class and we cannot retrieve the value. + // ListFields() has the same behavior. + if (fields[i]->message_type() != nullptr && + message_factory::GetMessageClass( + cmessage::GetFactoryForMessage(self->parent), + fields[i]->message_type()) == nullptr) { + PyErr_Clear(); + continue; + } + ++size; + } + } + return size; +} + +struct ExtensionIterator { + PyObject_HEAD; + Py_ssize_t index; + std::vector<const FieldDescriptor*> fields; + + // Owned reference, to keep the FieldDescriptors alive. + ExtensionDict* extension_dict; +}; + +PyObject* GetIter(PyObject* _self) { + ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self); + + ScopedPyObjectPtr obj(PyType_GenericAlloc(&ExtensionIterator_Type, 0)); + if (obj == nullptr) { + return PyErr_Format(PyExc_MemoryError, + "Could not allocate extension iterator"); + } + + ExtensionIterator* iter = reinterpret_cast<ExtensionIterator*>(obj.get()); + + // Call "placement new" to initialize. So the constructor of + // std::vector<...> fields will be called. + new (iter) ExtensionIterator; + + self->parent->message->GetReflection()->ListFields(*self->parent->message, + &iter->fields); + iter->index = 0; + Py_INCREF(self); + iter->extension_dict = self; + + return obj.release(); +} + +static void DeallocExtensionIterator(PyObject* _self) { + ExtensionIterator* self = reinterpret_cast<ExtensionIterator*>(_self); + self->fields.clear(); + Py_XDECREF(self->extension_dict); + self->~ExtensionIterator(); + Py_TYPE(_self)->tp_free(_self); +} + +PyObject* subscript(ExtensionDict* self, PyObject* key) { + const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key); + if (descriptor == NULL) { + return NULL; + } + if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) { + return NULL; + } + + if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && + descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { + return cmessage::InternalGetScalar(self->parent->message, descriptor); + } + + CMessage::CompositeFieldsMap::iterator iterator = + self->parent->composite_fields->find(descriptor); + if (iterator != self->parent->composite_fields->end()) { + Py_INCREF(iterator->second); + return iterator->second->AsPyObject(); + } + + if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && + descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + // TODO(plabatut): consider building the class on the fly! + ContainerBase* sub_message = cmessage::InternalGetSubMessage( + self->parent, descriptor); + if (sub_message == NULL) { + return NULL; + } + (*self->parent->composite_fields)[descriptor] = sub_message; + return sub_message->AsPyObject(); + } + + if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + // On the fly message class creation is needed to support the following + // situation: + // 1- add FileDescriptor to the pool that contains extensions of a message + // defined by another proto file. Do not create any message classes. + // 2- instantiate an extended message, and access the extension using + // the field descriptor. + // 3- the extension submessage fails to be returned, because no class has + // been created. + // It happens when deserializing text proto format, or when enumerating + // fields of a deserialized message. + CMessageClass* message_class = message_factory::GetOrCreateMessageClass( + cmessage::GetFactoryForMessage(self->parent), + descriptor->message_type()); + ScopedPyObjectPtr message_class_handler( + reinterpret_cast<PyObject*>(message_class)); + if (message_class == NULL) { + return NULL; + } + ContainerBase* py_container = repeated_composite_container::NewContainer( + self->parent, descriptor, message_class); + if (py_container == NULL) { + return NULL; + } + (*self->parent->composite_fields)[descriptor] = py_container; + return py_container->AsPyObject(); + } else { + ContainerBase* py_container = repeated_scalar_container::NewContainer( + self->parent, descriptor); + if (py_container == NULL) { + return NULL; + } + (*self->parent->composite_fields)[descriptor] = py_container; + return py_container->AsPyObject(); + } + } + PyErr_SetString(PyExc_ValueError, "control reached unexpected line"); + return NULL; +} + +int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { + const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key); + if (descriptor == NULL) { + return -1; + } + if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) { + return -1; + } + + if (value == nullptr) { + return cmessage::ClearFieldByDescriptor(self->parent, descriptor); + } + + if (descriptor->label() != FieldDescriptor::LABEL_OPTIONAL || + descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + PyErr_SetString(PyExc_TypeError, "Extension is repeated and/or composite " + "type"); + return -1; + } + cmessage::AssureWritable(self->parent); + if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) { + return -1; + } + return 0; +} + +PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) { + char* name; + Py_ssize_t name_size; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool; + const FieldDescriptor* message_extension = + pool->pool->FindExtensionByName(StringParam(name, name_size)); + if (message_extension == NULL) { + // Is is the name of a message set extension? + const Descriptor* message_descriptor = + pool->pool->FindMessageTypeByName(StringParam(name, name_size)); + if (message_descriptor && message_descriptor->extension_count() > 0) { + const FieldDescriptor* extension = message_descriptor->extension(0); + if (extension->is_extension() && + extension->containing_type()->options().message_set_wire_format() && + extension->type() == FieldDescriptor::TYPE_MESSAGE && + extension->label() == FieldDescriptor::LABEL_OPTIONAL) { + message_extension = extension; + } + } + } + if (message_extension == NULL) { + Py_RETURN_NONE; + } + + return PyFieldDescriptor_FromDescriptor(message_extension); +} + +PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) { + int64_t number = PyLong_AsLong(arg); + if (number == -1 && PyErr_Occurred()) { + return NULL; + } + + PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool; + const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber( + self->parent->message->GetDescriptor(), number); + if (message_extension == NULL) { + Py_RETURN_NONE; + } + + return PyFieldDescriptor_FromDescriptor(message_extension); +} + +static int Contains(PyObject* _self, PyObject* key) { + ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self); + const FieldDescriptor* field_descriptor = + cmessage::GetExtensionDescriptor(key); + if (field_descriptor == nullptr) { + return -1; + } + + if (!field_descriptor->is_extension()) { + PyErr_Format(PyExc_KeyError, "%s is not an extension", + field_descriptor->full_name().c_str()); + return -1; + } + + const Message* message = self->parent->message; + const Reflection* reflection = message->GetReflection(); + if (field_descriptor->is_repeated()) { + if (reflection->FieldSize(*message, field_descriptor) > 0) { + return 1; + } + } else { + if (reflection->HasField(*message, field_descriptor)) { + return 1; + } + } + + return 0; +} + +ExtensionDict* NewExtensionDict(CMessage *parent) { + ExtensionDict* self = reinterpret_cast<ExtensionDict*>( + PyType_GenericAlloc(&ExtensionDict_Type, 0)); + if (self == NULL) { + return NULL; + } + + Py_INCREF(parent); + self->parent = parent; + return self; +} + +void dealloc(PyObject* pself) { + ExtensionDict* self = reinterpret_cast<ExtensionDict*>(pself); + Py_CLEAR(self->parent); + Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); +} + +static PyObject* RichCompare(ExtensionDict* self, PyObject* other, int opid) { + // Only equality comparisons are implemented. + if (opid != Py_EQ && opid != Py_NE) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + bool equals = false; + if (PyObject_TypeCheck(other, &ExtensionDict_Type)) { + equals = self->parent == reinterpret_cast<ExtensionDict*>(other)->parent;; + } + if (equals ^ (opid == Py_EQ)) { + Py_RETURN_FALSE; + } else { + Py_RETURN_TRUE; + } +} +static PySequenceMethods SeqMethods = { + (lenfunc)len, // sq_length + 0, // sq_concat + 0, // sq_repeat + 0, // sq_item + 0, // sq_slice + 0, // sq_ass_item + 0, // sq_ass_slice + (objobjproc)Contains, // sq_contains +}; + +static PyMappingMethods MpMethods = { + (lenfunc)len, /* mp_length */ + (binaryfunc)subscript, /* mp_subscript */ + (objobjargproc)ass_subscript,/* mp_ass_subscript */ +}; + +#define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc } +static PyMethodDef Methods[] = { + EDMETHOD(_FindExtensionByName, METH_O, "Finds an extension by name."), + EDMETHOD(_FindExtensionByNumber, METH_O, + "Finds an extension by field number."), + {NULL, NULL}, +}; + +} // namespace extension_dict + +PyTypeObject ExtensionDict_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) // + FULL_MODULE_NAME ".ExtensionDict", // tp_name + sizeof(ExtensionDict), // tp_basicsize + 0, // tp_itemsize + (destructor)extension_dict::dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + &extension_dict::SeqMethods, // tp_as_sequence + &extension_dict::MpMethods, // tp_as_mapping + PyObject_HashNotImplemented, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "An extension dict", // tp_doc + 0, // tp_traverse + 0, // tp_clear + (richcmpfunc)extension_dict::RichCompare, // tp_richcompare + 0, // tp_weaklistoffset + extension_dict::GetIter, // tp_iter + 0, // tp_iternext + extension_dict::Methods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init +}; + +PyObject* IterNext(PyObject* _self) { + extension_dict::ExtensionIterator* self = + reinterpret_cast<extension_dict::ExtensionIterator*>(_self); + Py_ssize_t total_size = self->fields.size(); + Py_ssize_t index = self->index; + while (self->index < total_size) { + index = self->index; + ++self->index; + if (self->fields[index]->is_extension()) { + // With C++ descriptors, the field can always be retrieved, but for + // unknown extensions which have not been imported in Python code, there + // is no message class and we cannot retrieve the value. + // ListFields() has the same behavior. + if (self->fields[index]->message_type() != nullptr && + message_factory::GetMessageClass( + cmessage::GetFactoryForMessage(self->extension_dict->parent), + self->fields[index]->message_type()) == nullptr) { + PyErr_Clear(); + continue; + } + + return PyFieldDescriptor_FromDescriptor(self->fields[index]); + } + } + + return nullptr; +} + +PyTypeObject ExtensionIterator_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) // + FULL_MODULE_NAME ".ExtensionIterator", // tp_name + sizeof(extension_dict::ExtensionIterator), // tp_basicsize + 0, // tp_itemsize + extension_dict::DeallocExtensionIterator, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A scalar map iterator", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + PyObject_SelfIter, // tp_iter + IterNext, // tp_iternext + 0, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init +}; +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/extension_dict.h b/contrib/python/protobuf/py3/google/protobuf/pyext/extension_dict.h new file mode 100644 index 0000000000..c9da443161 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/extension_dict.h @@ -0,0 +1,71 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__ + +#include <Python.h> + +#include <memory> + +#include <google/protobuf/pyext/message.h> + +namespace google { +namespace protobuf { + +class Message; +class FieldDescriptor; + +namespace python { + +typedef struct ExtensionDict { + PyObject_HEAD; + + // Strong, owned reference to the parent message. Never NULL. + CMessage* parent; +} ExtensionDict; + +extern PyTypeObject ExtensionDict_Type; +extern PyTypeObject ExtensionIterator_Type; + +namespace extension_dict { + +// Builds an Extensions dict for a specific message. +ExtensionDict* NewExtensionDict(CMessage *parent); + +} // namespace extension_dict +} // namespace python +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__ diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/field.cc b/contrib/python/protobuf/py3/google/protobuf/pyext/field.cc new file mode 100644 index 0000000000..1afd4583b3 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/field.cc @@ -0,0 +1,142 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include <google/protobuf/pyext/field.h> + +#include <google/protobuf/descriptor.h> +#include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/message.h> + +#if PY_MAJOR_VERSION >= 3 + #define PyString_FromFormat PyUnicode_FromFormat +#endif + +namespace google { +namespace protobuf { +namespace python { + +namespace field { + +static PyObject* Repr(PyMessageFieldProperty* self) { + return PyString_FromFormat("<field property '%s'>", + self->field_descriptor->full_name().c_str()); +} + +static PyObject* DescrGet(PyMessageFieldProperty* self, PyObject* obj, + PyObject* type) { + if (obj == NULL) { + Py_INCREF(self); + return reinterpret_cast<PyObject*>(self); + } + return cmessage::GetFieldValue(reinterpret_cast<CMessage*>(obj), + self->field_descriptor); +} + +static int DescrSet(PyMessageFieldProperty* self, PyObject* obj, + PyObject* value) { + if (value == NULL) { + PyErr_SetString(PyExc_AttributeError, "Cannot delete field attribute"); + return -1; + } + return cmessage::SetFieldValue(reinterpret_cast<CMessage*>(obj), + self->field_descriptor, value); +} + +static PyObject* GetDescriptor(PyMessageFieldProperty* self, void* closure) { + return PyFieldDescriptor_FromDescriptor(self->field_descriptor); +} + +static PyObject* GetDoc(PyMessageFieldProperty* self, void* closure) { + return PyString_FromFormat("Field %s", + self->field_descriptor->full_name().c_str()); +} + +static PyGetSetDef Getters[] = { + {"DESCRIPTOR", (getter)GetDescriptor, NULL, "Field descriptor"}, + {"__doc__", (getter)GetDoc, NULL, NULL}, + {NULL}}; +} // namespace field + +static PyTypeObject _CFieldProperty_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) // head + FULL_MODULE_NAME ".FieldProperty", // tp_name + sizeof(PyMessageFieldProperty), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + (reprfunc)field::Repr, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "Field property of a Message", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + 0, // tp_methods + 0, // tp_members + field::Getters, // tp_getset + 0, // tp_base + 0, // tp_dict + (descrgetfunc)field::DescrGet, // tp_descr_get + (descrsetfunc)field::DescrSet, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + 0, // tp_new +}; +PyTypeObject* CFieldProperty_Type = &_CFieldProperty_Type; + +PyObject* NewFieldProperty(const FieldDescriptor* field_descriptor) { + // Create a new descriptor object + PyMessageFieldProperty* property = + PyObject_New(PyMessageFieldProperty, CFieldProperty_Type); + if (property == NULL) { + return NULL; + } + property->field_descriptor = field_descriptor; + return reinterpret_cast<PyObject*>(property); +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/field.h b/contrib/python/protobuf/py3/google/protobuf/pyext/field.h new file mode 100644 index 0000000000..7b4660cab5 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/field.h @@ -0,0 +1,59 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_FIELD_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_FIELD_H__ + +#include <Python.h> + +namespace google { +namespace protobuf { + +class FieldDescriptor; + +namespace python { + +// A data descriptor that represents a field in a Message class. +struct PyMessageFieldProperty { + PyObject_HEAD; + + // This pointer is owned by the same pool as the Message class it belongs to. + const FieldDescriptor* field_descriptor; +}; + +extern PyTypeObject* CFieldProperty_Type; + +PyObject* NewFieldProperty(const FieldDescriptor* field_descriptor); + +} // namespace python +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_FIELD_H__ diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/map_container.cc b/contrib/python/protobuf/py3/google/protobuf/pyext/map_container.cc new file mode 100644 index 0000000000..e7a9cca23b --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/map_container.cc @@ -0,0 +1,1059 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: haberman@google.com (Josh Haberman) + +#include <google/protobuf/pyext/map_container.h> + +#include <cstdint> +#include <memory> + +#include <google/protobuf/stubs/logging.h> +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/map.h> +#include <google/protobuf/map_field.h> +#include <google/protobuf/message.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/message_factory.h> +#include <google/protobuf/pyext/repeated_composite_container.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> +#include <google/protobuf/stubs/map_util.h> + +#if PY_MAJOR_VERSION >= 3 + #define PyInt_FromLong PyLong_FromLong + #define PyInt_FromSize_t PyLong_FromSize_t +#endif + +namespace google { +namespace protobuf { +namespace python { + +// Functions that need access to map reflection functionality. +// They need to be contained in this class because it is friended. +class MapReflectionFriend { + public: + // Methods that are in common between the map types. + static PyObject* Contains(PyObject* _self, PyObject* key); + static Py_ssize_t Length(PyObject* _self); + static PyObject* GetIterator(PyObject *_self); + static PyObject* IterNext(PyObject* _self); + static PyObject* MergeFrom(PyObject* _self, PyObject* arg); + + // Methods that differ between the map types. + static PyObject* ScalarMapGetItem(PyObject* _self, PyObject* key); + static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key); + static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v); + static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v); + static PyObject* ScalarMapToStr(PyObject* _self); + static PyObject* MessageMapToStr(PyObject* _self); +}; + +struct MapIterator { + PyObject_HEAD; + + std::unique_ptr<::google::protobuf::MapIterator> iter; + + // A pointer back to the container, so we can notice changes to the version. + // We own a ref on this. + MapContainer* container; + + // We need to keep a ref on the parent Message too, because + // MapIterator::~MapIterator() accesses it. Normally this would be ok because + // the ref on container (above) would guarantee outlive semantics. However in + // the case of ClearField(), the MapContainer points to a different message, + // a copy of the original. But our iterator still points to the original, + // which could now get deleted before us. + // + // To prevent this, we ensure that the Message will always stay alive as long + // as this iterator does. This is solely for the benefit of the MapIterator + // destructor -- we should never actually access the iterator in this state + // except to delete it. + CMessage* parent; + // The version of the map when we took the iterator to it. + // + // We store this so that if the map is modified during iteration we can throw + // an error. + uint64_t version; +}; + +Message* MapContainer::GetMutableMessage() { + cmessage::AssureWritable(parent); + return parent->message; +} + +// Consumes a reference on the Python string object. +static bool PyStringToSTL(PyObject* py_string, TProtoStringType* stl_string) { + char *value; + Py_ssize_t value_len; + + if (!py_string) { + return false; + } + if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) { + Py_DECREF(py_string); + return false; + } else { + stl_string->assign(value, value_len); + Py_DECREF(py_string); + return true; + } +} + +static bool PythonToMapKey(MapContainer* self, PyObject* obj, MapKey* key) { + const FieldDescriptor* field_descriptor = + self->parent_field_descriptor->message_type()->map_key(); + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: { + GOOGLE_CHECK_GET_INT32(obj, value, false); + key->SetInt32Value(value); + break; + } + case FieldDescriptor::CPPTYPE_INT64: { + GOOGLE_CHECK_GET_INT64(obj, value, false); + key->SetInt64Value(value); + break; + } + case FieldDescriptor::CPPTYPE_UINT32: { + GOOGLE_CHECK_GET_UINT32(obj, value, false); + key->SetUInt32Value(value); + break; + } + case FieldDescriptor::CPPTYPE_UINT64: { + GOOGLE_CHECK_GET_UINT64(obj, value, false); + key->SetUInt64Value(value); + break; + } + case FieldDescriptor::CPPTYPE_BOOL: { + GOOGLE_CHECK_GET_BOOL(obj, value, false); + key->SetBoolValue(value); + break; + } + case FieldDescriptor::CPPTYPE_STRING: { + TProtoStringType str; + if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) { + return false; + } + key->SetStringValue(str); + break; + } + default: + PyErr_Format( + PyExc_SystemError, "Type %d cannot be a map key", + field_descriptor->cpp_type()); + return false; + } + return true; +} + +static PyObject* MapKeyToPython(MapContainer* self, const MapKey& key) { + const FieldDescriptor* field_descriptor = + self->parent_field_descriptor->message_type()->map_key(); + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return PyInt_FromLong(key.GetInt32Value()); + case FieldDescriptor::CPPTYPE_INT64: + return PyLong_FromLongLong(key.GetInt64Value()); + case FieldDescriptor::CPPTYPE_UINT32: + return PyInt_FromSize_t(key.GetUInt32Value()); + case FieldDescriptor::CPPTYPE_UINT64: + return PyLong_FromUnsignedLongLong(key.GetUInt64Value()); + case FieldDescriptor::CPPTYPE_BOOL: + return PyBool_FromLong(key.GetBoolValue()); + case FieldDescriptor::CPPTYPE_STRING: + return ToStringObject(field_descriptor, key.GetStringValue()); + default: + PyErr_Format( + PyExc_SystemError, "Couldn't convert type %d to value", + field_descriptor->cpp_type()); + return NULL; + } +} + +// This is only used for ScalarMap, so we don't need to handle the +// CPPTYPE_MESSAGE case. +PyObject* MapValueRefToPython(MapContainer* self, const MapValueRef& value) { + const FieldDescriptor* field_descriptor = + self->parent_field_descriptor->message_type()->map_value(); + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return PyInt_FromLong(value.GetInt32Value()); + case FieldDescriptor::CPPTYPE_INT64: + return PyLong_FromLongLong(value.GetInt64Value()); + case FieldDescriptor::CPPTYPE_UINT32: + return PyInt_FromSize_t(value.GetUInt32Value()); + case FieldDescriptor::CPPTYPE_UINT64: + return PyLong_FromUnsignedLongLong(value.GetUInt64Value()); + case FieldDescriptor::CPPTYPE_FLOAT: + return PyFloat_FromDouble(value.GetFloatValue()); + case FieldDescriptor::CPPTYPE_DOUBLE: + return PyFloat_FromDouble(value.GetDoubleValue()); + case FieldDescriptor::CPPTYPE_BOOL: + return PyBool_FromLong(value.GetBoolValue()); + case FieldDescriptor::CPPTYPE_STRING: + return ToStringObject(field_descriptor, value.GetStringValue()); + case FieldDescriptor::CPPTYPE_ENUM: + return PyInt_FromLong(value.GetEnumValue()); + default: + PyErr_Format( + PyExc_SystemError, "Couldn't convert type %d to value", + field_descriptor->cpp_type()); + return NULL; + } +} + +// This is only used for ScalarMap, so we don't need to handle the +// CPPTYPE_MESSAGE case. +static bool PythonToMapValueRef(MapContainer* self, PyObject* obj, + bool allow_unknown_enum_values, + MapValueRef* value_ref) { + const FieldDescriptor* field_descriptor = + self->parent_field_descriptor->message_type()->map_value(); + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: { + GOOGLE_CHECK_GET_INT32(obj, value, false); + value_ref->SetInt32Value(value); + return true; + } + case FieldDescriptor::CPPTYPE_INT64: { + GOOGLE_CHECK_GET_INT64(obj, value, false); + value_ref->SetInt64Value(value); + return true; + } + case FieldDescriptor::CPPTYPE_UINT32: { + GOOGLE_CHECK_GET_UINT32(obj, value, false); + value_ref->SetUInt32Value(value); + return true; + } + case FieldDescriptor::CPPTYPE_UINT64: { + GOOGLE_CHECK_GET_UINT64(obj, value, false); + value_ref->SetUInt64Value(value); + return true; + } + case FieldDescriptor::CPPTYPE_FLOAT: { + GOOGLE_CHECK_GET_FLOAT(obj, value, false); + value_ref->SetFloatValue(value); + return true; + } + case FieldDescriptor::CPPTYPE_DOUBLE: { + GOOGLE_CHECK_GET_DOUBLE(obj, value, false); + value_ref->SetDoubleValue(value); + return true; + } + case FieldDescriptor::CPPTYPE_BOOL: { + GOOGLE_CHECK_GET_BOOL(obj, value, false); + value_ref->SetBoolValue(value); + return true;; + } + case FieldDescriptor::CPPTYPE_STRING: { + TProtoStringType str; + if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) { + return false; + } + value_ref->SetStringValue(str); + return true; + } + case FieldDescriptor::CPPTYPE_ENUM: { + GOOGLE_CHECK_GET_INT32(obj, value, false); + if (allow_unknown_enum_values) { + value_ref->SetEnumValue(value); + return true; + } else { + const EnumDescriptor* enum_descriptor = field_descriptor->enum_type(); + const EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != NULL) { + value_ref->SetEnumValue(value); + return true; + } else { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value); + return false; + } + } + break; + } + default: + PyErr_Format( + PyExc_SystemError, "Setting value to a field of unknown type %d", + field_descriptor->cpp_type()); + return false; + } +} + +// Map methods common to ScalarMap and MessageMap ////////////////////////////// + +static MapContainer* GetMap(PyObject* obj) { + return reinterpret_cast<MapContainer*>(obj); +} + +Py_ssize_t MapReflectionFriend::Length(PyObject* _self) { + MapContainer* self = GetMap(_self); + const google::protobuf::Message* message = self->parent->message; + return message->GetReflection()->MapSize(*message, + self->parent_field_descriptor); +} + +PyObject* Clear(PyObject* _self) { + MapContainer* self = GetMap(_self); + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + + reflection->ClearField(message, self->parent_field_descriptor); + + Py_RETURN_NONE; +} + +PyObject* GetEntryClass(PyObject* _self) { + MapContainer* self = GetMap(_self); + CMessageClass* message_class = message_factory::GetMessageClass( + cmessage::GetFactoryForMessage(self->parent), + self->parent_field_descriptor->message_type()); + Py_XINCREF(message_class); + return reinterpret_cast<PyObject*>(message_class); +} + +PyObject* MapReflectionFriend::MergeFrom(PyObject* _self, PyObject* arg) { + MapContainer* self = GetMap(_self); + if (!PyObject_TypeCheck(arg, ScalarMapContainer_Type) && + !PyObject_TypeCheck(arg, MessageMapContainer_Type)) { + PyErr_SetString(PyExc_AttributeError, "Not a map field"); + return nullptr; + } + MapContainer* other_map = GetMap(arg); + Message* message = self->GetMutableMessage(); + const Message* other_message = other_map->parent->message; + const Reflection* reflection = message->GetReflection(); + const Reflection* other_reflection = other_message->GetReflection(); + internal::MapFieldBase* field = reflection->MutableMapData( + message, self->parent_field_descriptor); + const internal::MapFieldBase* other_field = other_reflection->GetMapData( + *other_message, other_map->parent_field_descriptor); + field->MergeFrom(*other_field); + self->version++; + Py_RETURN_NONE; +} + +PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) { + MapContainer* self = GetMap(_self); + + const Message* message = self->parent->message; + const Reflection* reflection = message->GetReflection(); + MapKey map_key; + + if (!PythonToMapKey(self, key, &map_key)) { + return NULL; + } + + if (reflection->ContainsMapKey(*message, self->parent_field_descriptor, + map_key)) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + +// ScalarMap /////////////////////////////////////////////////////////////////// + +MapContainer* NewScalarMapContainer( + CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) { + if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { + return NULL; + } + + PyObject* obj(PyType_GenericAlloc(ScalarMapContainer_Type, 0)); + if (obj == NULL) { + PyErr_Format(PyExc_RuntimeError, + "Could not allocate new container."); + return NULL; + } + + MapContainer* self = GetMap(obj); + + Py_INCREF(parent); + self->parent = parent; + self->parent_field_descriptor = parent_field_descriptor; + self->version = 0; + + return self; +} + +PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self, + PyObject* key) { + MapContainer* self = GetMap(_self); + + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + MapKey map_key; + MapValueRef value; + + if (!PythonToMapKey(self, key, &map_key)) { + return NULL; + } + + if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor, + map_key, &value)) { + self->version++; + } + + return MapValueRefToPython(self, value); +} + +int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key, + PyObject* v) { + MapContainer* self = GetMap(_self); + + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + MapKey map_key; + MapValueRef value; + + if (!PythonToMapKey(self, key, &map_key)) { + return -1; + } + + self->version++; + + if (v) { + // Set item to v. + reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor, + map_key, &value); + + if (!PythonToMapValueRef(self, v, reflection->SupportsUnknownEnumValues(), + &value)) { + return -1; + } + return 0; + } else { + // Delete key from map. + if (reflection->DeleteMapValue(message, self->parent_field_descriptor, + map_key)) { + return 0; + } else { + PyErr_Format(PyExc_KeyError, "Key not present in map"); + return -1; + } + } +} + +static PyObject* ScalarMapGet(PyObject* self, PyObject* args, + PyObject* kwargs) { + static const char* kwlist[] = {"key", "default", nullptr}; + PyObject* key; + PyObject* default_value = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O", + const_cast<char**>(kwlist), &key, + &default_value)) { + return NULL; + } + + ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key)); + if (is_present.get() == NULL) { + return NULL; + } + + if (PyObject_IsTrue(is_present.get())) { + return MapReflectionFriend::ScalarMapGetItem(self, key); + } else { + if (default_value != NULL) { + Py_INCREF(default_value); + return default_value; + } else { + Py_RETURN_NONE; + } + } +} + +PyObject* MapReflectionFriend::ScalarMapToStr(PyObject* _self) { + ScopedPyObjectPtr dict(PyDict_New()); + if (dict == NULL) { + return NULL; + } + ScopedPyObjectPtr key; + ScopedPyObjectPtr value; + + MapContainer* self = GetMap(_self); + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + for (google::protobuf::MapIterator it = reflection->MapBegin( + message, self->parent_field_descriptor); + it != reflection->MapEnd(message, self->parent_field_descriptor); + ++it) { + key.reset(MapKeyToPython(self, it.GetKey())); + if (key == NULL) { + return NULL; + } + value.reset(MapValueRefToPython(self, it.GetValueRef())); + if (value == NULL) { + return NULL; + } + if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) { + return NULL; + } + } + return PyObject_Repr(dict.get()); +} + +static void ScalarMapDealloc(PyObject* _self) { + MapContainer* self = GetMap(_self); + self->RemoveFromParentCache(); + PyTypeObject *type = Py_TYPE(_self); + type->tp_free(_self); + if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) { + // With Python3, the Map class is not static, and must be managed. + Py_DECREF(type); + } +} + +static PyMethodDef ScalarMapMethods[] = { + {"__contains__", MapReflectionFriend::Contains, METH_O, + "Tests whether a key is a member of the map."}, + {"clear", (PyCFunction)Clear, METH_NOARGS, + "Removes all elements from the map."}, + {"get", (PyCFunction)ScalarMapGet, METH_VARARGS | METH_KEYWORDS, + "Gets the value for the given key if present, or otherwise a default"}, + {"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS, + "Return the class used to build Entries of (key, value) pairs."}, + {"MergeFrom", (PyCFunction)MapReflectionFriend::MergeFrom, METH_O, + "Merges a map into the current map."}, + /* + { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, + "Makes a deep copy of the class." }, + { "__reduce__", (PyCFunction)Reduce, METH_NOARGS, + "Outputs picklable representation of the repeated field." }, + */ + {NULL, NULL}, +}; + +PyTypeObject *ScalarMapContainer_Type; +#if PY_MAJOR_VERSION >= 3 + static PyType_Slot ScalarMapContainer_Type_slots[] = { + {Py_tp_dealloc, (void *)ScalarMapDealloc}, + {Py_mp_length, (void *)MapReflectionFriend::Length}, + {Py_mp_subscript, (void *)MapReflectionFriend::ScalarMapGetItem}, + {Py_mp_ass_subscript, (void *)MapReflectionFriend::ScalarMapSetItem}, + {Py_tp_methods, (void *)ScalarMapMethods}, + {Py_tp_iter, (void *)MapReflectionFriend::GetIterator}, + {Py_tp_repr, (void *)MapReflectionFriend::ScalarMapToStr}, + {0, 0}, + }; + + PyType_Spec ScalarMapContainer_Type_spec = { + FULL_MODULE_NAME ".ScalarMapContainer", + sizeof(MapContainer), + 0, + Py_TPFLAGS_DEFAULT, + ScalarMapContainer_Type_slots + }; +#else + static PyMappingMethods ScalarMapMappingMethods = { + MapReflectionFriend::Length, // mp_length + MapReflectionFriend::ScalarMapGetItem, // mp_subscript + MapReflectionFriend::ScalarMapSetItem, // mp_ass_subscript + }; + + PyTypeObject _ScalarMapContainer_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".ScalarMapContainer", // tp_name + sizeof(MapContainer), // tp_basicsize + 0, // tp_itemsize + ScalarMapDealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + MapReflectionFriend::ScalarMapToStr, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + &ScalarMapMappingMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A scalar map container", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + MapReflectionFriend::GetIterator, // tp_iter + 0, // tp_iternext + ScalarMapMethods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + }; +#endif + + +// MessageMap ////////////////////////////////////////////////////////////////// + +static MessageMapContainer* GetMessageMap(PyObject* obj) { + return reinterpret_cast<MessageMapContainer*>(obj); +} + +static PyObject* GetCMessage(MessageMapContainer* self, Message* message) { + // Get or create the CMessage object corresponding to this message. + return self->parent + ->BuildSubMessageFromPointer(self->parent_field_descriptor, message, + self->message_class) + ->AsPyObject(); +} + +MessageMapContainer* NewMessageMapContainer( + CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor, + CMessageClass* message_class) { + if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { + return NULL; + } + + PyObject* obj = PyType_GenericAlloc(MessageMapContainer_Type, 0); + if (obj == NULL) { + PyErr_SetString(PyExc_RuntimeError, "Could not allocate new container."); + return NULL; + } + + MessageMapContainer* self = GetMessageMap(obj); + + Py_INCREF(parent); + self->parent = parent; + self->parent_field_descriptor = parent_field_descriptor; + self->version = 0; + + Py_INCREF(message_class); + self->message_class = message_class; + + return self; +} + +int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key, + PyObject* v) { + if (v) { + PyErr_Format(PyExc_ValueError, + "Direct assignment of submessage not allowed"); + return -1; + } + + // Now we know that this is a delete, not a set. + + MessageMapContainer* self = GetMessageMap(_self); + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + MapKey map_key; + MapValueRef value; + + self->version++; + + if (!PythonToMapKey(self, key, &map_key)) { + return -1; + } + + // Delete key from map. + if (reflection->ContainsMapKey(*message, self->parent_field_descriptor, + map_key)) { + // Delete key from CMessage dict. + MapValueRef value; + reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor, + map_key, &value); + Message* sub_message = value.MutableMessageValue(); + // If there is a living weak reference to an item, we "Release" it, + // otherwise we just discard the C++ value. + if (CMessage* released = + self->parent->MaybeReleaseSubMessage(sub_message)) { + Message* msg = released->message; + released->message = msg->New(); + msg->GetReflection()->Swap(msg, released->message); + } + + // Delete key from map. + reflection->DeleteMapValue(message, self->parent_field_descriptor, + map_key); + return 0; + } else { + PyErr_Format(PyExc_KeyError, "Key not present in map"); + return -1; + } +} + +PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self, + PyObject* key) { + MessageMapContainer* self = GetMessageMap(_self); + + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + MapKey map_key; + MapValueRef value; + + if (!PythonToMapKey(self, key, &map_key)) { + return NULL; + } + + if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor, + map_key, &value)) { + self->version++; + } + + return GetCMessage(self, value.MutableMessageValue()); +} + +PyObject* MapReflectionFriend::MessageMapToStr(PyObject* _self) { + ScopedPyObjectPtr dict(PyDict_New()); + if (dict == NULL) { + return NULL; + } + ScopedPyObjectPtr key; + ScopedPyObjectPtr value; + + MessageMapContainer* self = GetMessageMap(_self); + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + for (google::protobuf::MapIterator it = reflection->MapBegin( + message, self->parent_field_descriptor); + it != reflection->MapEnd(message, self->parent_field_descriptor); + ++it) { + key.reset(MapKeyToPython(self, it.GetKey())); + if (key == NULL) { + return NULL; + } + value.reset(GetCMessage(self, it.MutableValueRef()->MutableMessageValue())); + if (value == NULL) { + return NULL; + } + if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) { + return NULL; + } + } + return PyObject_Repr(dict.get()); +} + +PyObject* MessageMapGet(PyObject* self, PyObject* args, PyObject* kwargs) { + static const char* kwlist[] = {"key", "default", nullptr}; + PyObject* key; + PyObject* default_value = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O", + const_cast<char**>(kwlist), &key, + &default_value)) { + return NULL; + } + + ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key)); + if (is_present.get() == NULL) { + return NULL; + } + + if (PyObject_IsTrue(is_present.get())) { + return MapReflectionFriend::MessageMapGetItem(self, key); + } else { + if (default_value != NULL) { + Py_INCREF(default_value); + return default_value; + } else { + Py_RETURN_NONE; + } + } +} + +static void MessageMapDealloc(PyObject* _self) { + MessageMapContainer* self = GetMessageMap(_self); + self->RemoveFromParentCache(); + Py_DECREF(self->message_class); + PyTypeObject *type = Py_TYPE(_self); + type->tp_free(_self); + if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) { + // With Python3, the Map class is not static, and must be managed. + Py_DECREF(type); + } +} + +static PyMethodDef MessageMapMethods[] = { + {"__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O, + "Tests whether the map contains this element."}, + {"clear", (PyCFunction)Clear, METH_NOARGS, + "Removes all elements from the map."}, + {"get", (PyCFunction)MessageMapGet, METH_VARARGS | METH_KEYWORDS, + "Gets the value for the given key if present, or otherwise a default"}, + {"get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O, + "Alias for getitem, useful to make explicit that the map is mutated."}, + {"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS, + "Return the class used to build Entries of (key, value) pairs."}, + {"MergeFrom", (PyCFunction)MapReflectionFriend::MergeFrom, METH_O, + "Merges a map into the current map."}, + /* + { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, + "Makes a deep copy of the class." }, + { "__reduce__", (PyCFunction)Reduce, METH_NOARGS, + "Outputs picklable representation of the repeated field." }, + */ + {NULL, NULL}, +}; + +PyTypeObject *MessageMapContainer_Type; +#if PY_MAJOR_VERSION >= 3 + static PyType_Slot MessageMapContainer_Type_slots[] = { + {Py_tp_dealloc, (void *)MessageMapDealloc}, + {Py_mp_length, (void *)MapReflectionFriend::Length}, + {Py_mp_subscript, (void *)MapReflectionFriend::MessageMapGetItem}, + {Py_mp_ass_subscript, (void *)MapReflectionFriend::MessageMapSetItem}, + {Py_tp_methods, (void *)MessageMapMethods}, + {Py_tp_iter, (void *)MapReflectionFriend::GetIterator}, + {Py_tp_repr, (void *)MapReflectionFriend::MessageMapToStr}, + {0, 0} + }; + + PyType_Spec MessageMapContainer_Type_spec = { + FULL_MODULE_NAME ".MessageMapContainer", + sizeof(MessageMapContainer), + 0, + Py_TPFLAGS_DEFAULT, + MessageMapContainer_Type_slots + }; +#else + static PyMappingMethods MessageMapMappingMethods = { + MapReflectionFriend::Length, // mp_length + MapReflectionFriend::MessageMapGetItem, // mp_subscript + MapReflectionFriend::MessageMapSetItem, // mp_ass_subscript + }; + + PyTypeObject _MessageMapContainer_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".MessageMapContainer", // tp_name + sizeof(MessageMapContainer), // tp_basicsize + 0, // tp_itemsize + MessageMapDealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + MapReflectionFriend::MessageMapToStr, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + &MessageMapMappingMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A map container for message", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + MapReflectionFriend::GetIterator, // tp_iter + 0, // tp_iternext + MessageMapMethods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + }; +#endif + +// MapIterator ///////////////////////////////////////////////////////////////// + +static MapIterator* GetIter(PyObject* obj) { + return reinterpret_cast<MapIterator*>(obj); +} + +PyObject* MapReflectionFriend::GetIterator(PyObject *_self) { + MapContainer* self = GetMap(_self); + + ScopedPyObjectPtr obj(PyType_GenericAlloc(&MapIterator_Type, 0)); + if (obj == NULL) { + return PyErr_Format(PyExc_KeyError, "Could not allocate iterator"); + } + + MapIterator* iter = GetIter(obj.get()); + + Py_INCREF(self); + iter->container = self; + iter->version = self->version; + Py_INCREF(self->parent); + iter->parent = self->parent; + + if (MapReflectionFriend::Length(_self) > 0) { + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + + iter->iter.reset(new ::google::protobuf::MapIterator( + reflection->MapBegin(message, self->parent_field_descriptor))); + } + + return obj.release(); +} + +PyObject* MapReflectionFriend::IterNext(PyObject* _self) { + MapIterator* self = GetIter(_self); + + // This won't catch mutations to the map performed by MergeFrom(); no easy way + // to address that. + if (self->version != self->container->version) { + return PyErr_Format(PyExc_RuntimeError, + "Map modified during iteration."); + } + if (self->parent != self->container->parent) { + return PyErr_Format(PyExc_RuntimeError, + "Map cleared during iteration."); + } + + if (self->iter.get() == NULL) { + return NULL; + } + + Message* message = self->container->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + + if (*self->iter == + reflection->MapEnd(message, self->container->parent_field_descriptor)) { + return NULL; + } + + PyObject* ret = MapKeyToPython(self->container, self->iter->GetKey()); + + ++(*self->iter); + + return ret; +} + +static void DeallocMapIterator(PyObject* _self) { + MapIterator* self = GetIter(_self); + self->iter.reset(); + Py_CLEAR(self->container); + Py_CLEAR(self->parent); + Py_TYPE(_self)->tp_free(_self); +} + +PyTypeObject MapIterator_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".MapIterator", // tp_name + sizeof(MapIterator), // tp_basicsize + 0, // tp_itemsize + DeallocMapIterator, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A scalar map iterator", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + PyObject_SelfIter, // tp_iter + MapReflectionFriend::IterNext, // tp_iternext + 0, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init +}; + +bool InitMapContainers() { + // ScalarMapContainer_Type derives from our MutableMapping type. + ScopedPyObjectPtr containers(PyImport_ImportModule( + "google.protobuf.internal.containers")); + if (containers == NULL) { + return false; + } + + ScopedPyObjectPtr mutable_mapping( + PyObject_GetAttrString(containers.get(), "MutableMapping")); + if (mutable_mapping == NULL) { + return false; + } + + Py_INCREF(mutable_mapping.get()); +#if PY_MAJOR_VERSION >= 3 + ScopedPyObjectPtr bases(PyTuple_Pack(1, mutable_mapping.get())); + if (bases == NULL) { + return false; + } + + ScalarMapContainer_Type = reinterpret_cast<PyTypeObject*>( + PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases.get())); +#else + _ScalarMapContainer_Type.tp_base = + reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); + + if (PyType_Ready(&_ScalarMapContainer_Type) < 0) { + return false; + } + + ScalarMapContainer_Type = &_ScalarMapContainer_Type; +#endif + + if (PyType_Ready(&MapIterator_Type) < 0) { + return false; + } + +#if PY_MAJOR_VERSION >= 3 + MessageMapContainer_Type = reinterpret_cast<PyTypeObject*>( + PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases.get())); +#else + Py_INCREF(mutable_mapping.get()); + _MessageMapContainer_Type.tp_base = + reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); + + if (PyType_Ready(&_MessageMapContainer_Type) < 0) { + return false; + } + + MessageMapContainer_Type = &_MessageMapContainer_Type; +#endif + return true; +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/map_container.h b/contrib/python/protobuf/py3/google/protobuf/pyext/map_container.h new file mode 100644 index 0000000000..842602e79f --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/map_container.h @@ -0,0 +1,89 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__ + +#include <Python.h> + +#include <cstdint> +#include <memory> + +#include <google/protobuf/descriptor.h> +#include <google/protobuf/message.h> +#include <google/protobuf/pyext/message.h> + +namespace google { +namespace protobuf { + +class Message; + +namespace python { + +struct CMessageClass; + +// This struct is used directly for ScalarMap, and is the base class of +// MessageMapContainer, which is used for MessageMap. +struct MapContainer : public ContainerBase { + // Use to get a mutable message when necessary. + Message* GetMutableMessage(); + + // We bump this whenever we perform a mutation, to invalidate existing + // iterators. + uint64_t version; +}; + +struct MessageMapContainer : public MapContainer { + // The type used to create new child messages. + CMessageClass* message_class; +}; + +bool InitMapContainers(); + +extern PyTypeObject* MessageMapContainer_Type; +extern PyTypeObject* ScalarMapContainer_Type; +extern PyTypeObject MapIterator_Type; // Both map types use the same iterator. + +// Builds a MapContainer object, from a parent message and a +// field descriptor. +extern MapContainer* NewScalarMapContainer( + CMessage* parent, const FieldDescriptor* parent_field_descriptor); + +// Builds a MessageMap object, from a parent message and a +// field descriptor. +extern MessageMapContainer* NewMessageMapContainer( + CMessage* parent, const FieldDescriptor* parent_field_descriptor, + CMessageClass* message_class); + +} // namespace python +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__ diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/message.cc b/contrib/python/protobuf/py3/google/protobuf/pyext/message.cc new file mode 100644 index 0000000000..8b41ca47dd --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/message.cc @@ -0,0 +1,3122 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#include <google/protobuf/pyext/message.h> + +#include <structmember.h> // A Python header file. + +#include <cstdint> +#include <map> +#include <memory> +#include <string> +#include <vector> + +#include <google/protobuf/stubs/strutil.h> + +#ifndef PyVarObject_HEAD_INIT +#define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size, +#endif +#ifndef Py_TYPE +#define Py_TYPE(ob) (((PyObject*)(ob))->ob_type) +#endif +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/stubs/logging.h> +#include <google/protobuf/io/coded_stream.h> +#include <google/protobuf/io/zero_copy_stream_impl_lite.h> +#include <google/protobuf/descriptor.pb.h> +#include <google/protobuf/descriptor.h> +#include <google/protobuf/message.h> +#include <google/protobuf/text_format.h> +#include <google/protobuf/unknown_field_set.h> +#include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/descriptor_pool.h> +#include <google/protobuf/pyext/extension_dict.h> +#include <google/protobuf/pyext/field.h> +#include <google/protobuf/pyext/map_container.h> +#include <google/protobuf/pyext/message_factory.h> +#include <google/protobuf/pyext/repeated_composite_container.h> +#include <google/protobuf/pyext/repeated_scalar_container.h> +#include <google/protobuf/pyext/safe_numerics.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> +#include <google/protobuf/pyext/unknown_fields.h> +#include <google/protobuf/util/message_differencer.h> +#include <google/protobuf/io/strtod.h> +#include <google/protobuf/stubs/map_util.h> + +// clang-format off +#include <google/protobuf/port_def.inc> +// clang-format on + +#if PY_MAJOR_VERSION >= 3 + #define PyInt_AsLong PyLong_AsLong + #define PyInt_FromLong PyLong_FromLong + #define PyInt_FromSize_t PyLong_FromSize_t + #define PyString_Check PyUnicode_Check + #define PyString_FromString PyUnicode_FromString + #define PyString_FromStringAndSize PyUnicode_FromStringAndSize + #define PyString_FromFormat PyUnicode_FromFormat + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #else + #define PyString_AsString(ob) \ + (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AsString(ob)) +#define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \ + PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \ + ? -1 \ + : 0) \ + : PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif +#endif + +namespace google { +namespace protobuf { +namespace python { + +class MessageReflectionFriend { + public: + static void UnsafeShallowSwapFields( + Message* lhs, Message* rhs, + const std::vector<const FieldDescriptor*>& fields) { + lhs->GetReflection()->UnsafeShallowSwapFields(lhs, rhs, fields); + } +}; + +static PyObject* kDESCRIPTOR; +PyObject* EnumTypeWrapper_class; +static PyObject* PythonMessage_class; +static PyObject* kEmptyWeakref; +static PyObject* WKT_classes = NULL; + +namespace message_meta { + +static int InsertEmptyWeakref(PyTypeObject* base); + +namespace { +// Copied over from internal 'google/protobuf/stubs/strutil.h'. +inline void LowerString(TProtoStringType* s) { + s->to_lower(); +} +} // namespace + +// Finalize the creation of the Message class. +static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { + // For each field set: cls.<field>_FIELD_NUMBER = <number> + for (int i = 0; i < descriptor->field_count(); ++i) { + const FieldDescriptor* field_descriptor = descriptor->field(i); + ScopedPyObjectPtr property(NewFieldProperty(field_descriptor)); + if (property == NULL) { + return -1; + } + if (PyObject_SetAttrString(cls, field_descriptor->name().c_str(), + property.get()) < 0) { + return -1; + } + } + + // For each enum set cls.<enum name> = EnumTypeWrapper(<enum descriptor>). + for (int i = 0; i < descriptor->enum_type_count(); ++i) { + const EnumDescriptor* enum_descriptor = descriptor->enum_type(i); + ScopedPyObjectPtr enum_type( + PyEnumDescriptor_FromDescriptor(enum_descriptor)); + if (enum_type == NULL) { + return -1; + } + // Add wrapped enum type to message class. + ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs( + EnumTypeWrapper_class, enum_type.get(), NULL)); + if (wrapped == NULL) { + return -1; + } + if (PyObject_SetAttrString( + cls, enum_descriptor->name().c_str(), wrapped.get()) == -1) { + return -1; + } + + // For each enum value add cls.<name> = <number> + for (int j = 0; j < enum_descriptor->value_count(); ++j) { + const EnumValueDescriptor* enum_value_descriptor = + enum_descriptor->value(j); + ScopedPyObjectPtr value_number(PyInt_FromLong( + enum_value_descriptor->number())); + if (value_number == NULL) { + return -1; + } + if (PyObject_SetAttrString(cls, enum_value_descriptor->name().c_str(), + value_number.get()) == -1) { + return -1; + } + } + } + + // For each extension set cls.<extension name> = <extension descriptor>. + // + // Extension descriptors come from + // <message descriptor>.extensions_by_name[name] + // which was defined previously. + for (int i = 0; i < descriptor->extension_count(); ++i) { + const google::protobuf::FieldDescriptor* field = descriptor->extension(i); + ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field)); + if (extension_field == NULL) { + return -1; + } + + // Add the extension field to the message class. + if (PyObject_SetAttrString( + cls, field->name().c_str(), extension_field.get()) == -1) { + return -1; + } + } + + return 0; +} + +static PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) { + static const char* kwlist[] = {"name", "bases", "dict", 0}; + PyObject *bases, *dict; + const char* name; + + // Check arguments: (name, bases, dict) + if (!PyArg_ParseTupleAndKeywords( + args, kwargs, "sO!O!:type", const_cast<char**>(kwlist), &name, + &PyTuple_Type, &bases, &PyDict_Type, &dict)) { + return NULL; + } + + // Check bases: only (), or (message.Message,) are allowed + if (!(PyTuple_GET_SIZE(bases) == 0 || + (PyTuple_GET_SIZE(bases) == 1 && + PyTuple_GET_ITEM(bases, 0) == PythonMessage_class))) { + PyErr_SetString(PyExc_TypeError, + "A Message class can only inherit from Message"); + return NULL; + } + + // Check dict['DESCRIPTOR'] + PyObject* py_descriptor = PyDict_GetItem(dict, kDESCRIPTOR); + if (py_descriptor == nullptr) { + PyErr_SetString(PyExc_TypeError, "Message class has no DESCRIPTOR"); + return nullptr; + } + if (!PyObject_TypeCheck(py_descriptor, &PyMessageDescriptor_Type)) { + PyErr_Format(PyExc_TypeError, "Expected a message Descriptor, got %s", + py_descriptor->ob_type->tp_name); + return nullptr; + } + const Descriptor* message_descriptor = + PyMessageDescriptor_AsDescriptor(py_descriptor); + if (message_descriptor == nullptr) { + return nullptr; + } + + // Messages have no __dict__ + ScopedPyObjectPtr slots(PyTuple_New(0)); + if (PyDict_SetItemString(dict, "__slots__", slots.get()) < 0) { + return NULL; + } + + // Build the arguments to the base metaclass. + // We change the __bases__ classes. + ScopedPyObjectPtr new_args; + + if (WKT_classes == NULL) { + ScopedPyObjectPtr well_known_types(PyImport_ImportModule( + "google.protobuf.internal.well_known_types")); + GOOGLE_DCHECK(well_known_types != NULL); + + WKT_classes = PyObject_GetAttrString(well_known_types.get(), "WKTBASES"); + GOOGLE_DCHECK(WKT_classes != NULL); + } + + PyObject* well_known_class = PyDict_GetItemString( + WKT_classes, message_descriptor->full_name().c_str()); + if (well_known_class == NULL) { + new_args.reset(Py_BuildValue("s(OO)O", name, CMessage_Type, + PythonMessage_class, dict)); + } else { + new_args.reset(Py_BuildValue("s(OOO)O", name, CMessage_Type, + PythonMessage_class, well_known_class, dict)); + } + + if (new_args == NULL) { + return NULL; + } + // Call the base metaclass. + ScopedPyObjectPtr result(PyType_Type.tp_new(type, new_args.get(), NULL)); + if (result == NULL) { + return NULL; + } + CMessageClass* newtype = reinterpret_cast<CMessageClass*>(result.get()); + + // Insert the empty weakref into the base classes. + if (InsertEmptyWeakref( + reinterpret_cast<PyTypeObject*>(PythonMessage_class)) < 0 || + InsertEmptyWeakref(CMessage_Type) < 0) { + return NULL; + } + + // Cache the descriptor, both as Python object and as C++ pointer. + const Descriptor* descriptor = + PyMessageDescriptor_AsDescriptor(py_descriptor); + if (descriptor == NULL) { + return NULL; + } + Py_INCREF(py_descriptor); + newtype->py_message_descriptor = py_descriptor; + newtype->message_descriptor = descriptor; + // TODO(amauryfa): Don't always use the canonical pool of the descriptor, + // use the MessageFactory optionally passed in the class dict. + PyDescriptorPool* py_descriptor_pool = + GetDescriptorPool_FromPool(descriptor->file()->pool()); + if (py_descriptor_pool == NULL) { + return NULL; + } + newtype->py_message_factory = py_descriptor_pool->py_message_factory; + Py_INCREF(newtype->py_message_factory); + + // Register the message in the MessageFactory. + // TODO(amauryfa): Move this call to MessageFactory.GetPrototype() when the + // MessageFactory is fully implemented in C++. + if (message_factory::RegisterMessageClass(newtype->py_message_factory, + descriptor, newtype) < 0) { + return NULL; + } + + // Continue with type initialization: add other descriptors, enum values... + if (AddDescriptors(result.get(), descriptor) < 0) { + return NULL; + } + return result.release(); +} + +static void Dealloc(PyObject* pself) { + CMessageClass* self = reinterpret_cast<CMessageClass*>(pself); + Py_XDECREF(self->py_message_descriptor); + Py_XDECREF(self->py_message_factory); + return PyType_Type.tp_dealloc(pself); +} + +static int GcTraverse(PyObject* pself, visitproc visit, void* arg) { + CMessageClass* self = reinterpret_cast<CMessageClass*>(pself); + Py_VISIT(self->py_message_descriptor); + Py_VISIT(self->py_message_factory); + return PyType_Type.tp_traverse(pself, visit, arg); +} + +static int GcClear(PyObject* pself) { + // It's important to keep the descriptor and factory alive, until the + // C++ message is fully destructed. + return PyType_Type.tp_clear(pself); +} + +// This function inserts and empty weakref at the end of the list of +// subclasses for the main protocol buffer Message class. +// +// This eliminates a O(n^2) behaviour in the internal add_subclass +// routine. +static int InsertEmptyWeakref(PyTypeObject *base_type) { +#if PY_MAJOR_VERSION >= 3 + // Python 3.4 has already included the fix for the issue that this + // hack addresses. For further background and the fix please see + // https://bugs.python.org/issue17936. + return 0; +#else +#ifdef Py_DEBUG + // The code below causes all new subclasses to append an entry, which is never + // cleared. This is a small memory leak, which we disable in Py_DEBUG mode + // to have stable refcounting checks. +#else + PyObject *subclasses = base_type->tp_subclasses; + if (subclasses && PyList_CheckExact(subclasses)) { + return PyList_Append(subclasses, kEmptyWeakref); + } +#endif // !Py_DEBUG + return 0; +#endif // PY_MAJOR_VERSION >= 3 +} + +// The _extensions_by_name dictionary is built on every access. +// TODO(amauryfa): Migrate all users to pool.FindAllExtensions() +static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) { + if (self->message_descriptor == NULL) { + // This is the base Message object, simply raise AttributeError. + PyErr_SetString(PyExc_AttributeError, + "Base Message class has no DESCRIPTOR"); + return NULL; + } + + const PyDescriptorPool* pool = self->py_message_factory->pool; + + std::vector<const FieldDescriptor*> extensions; + pool->pool->FindAllExtensions(self->message_descriptor, &extensions); + + ScopedPyObjectPtr result(PyDict_New()); + for (int i = 0; i < extensions.size(); i++) { + ScopedPyObjectPtr extension( + PyFieldDescriptor_FromDescriptor(extensions[i])); + if (extension == NULL) { + return NULL; + } + if (PyDict_SetItemString(result.get(), extensions[i]->full_name().c_str(), + extension.get()) < 0) { + return NULL; + } + } + return result.release(); +} + +// The _extensions_by_number dictionary is built on every access. +// TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber() +static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) { + if (self->message_descriptor == NULL) { + // This is the base Message object, simply raise AttributeError. + PyErr_SetString(PyExc_AttributeError, + "Base Message class has no DESCRIPTOR"); + return NULL; + } + + const PyDescriptorPool* pool = self->py_message_factory->pool; + + std::vector<const FieldDescriptor*> extensions; + pool->pool->FindAllExtensions(self->message_descriptor, &extensions); + + ScopedPyObjectPtr result(PyDict_New()); + for (int i = 0; i < extensions.size(); i++) { + ScopedPyObjectPtr extension( + PyFieldDescriptor_FromDescriptor(extensions[i])); + if (extension == NULL) { + return NULL; + } + ScopedPyObjectPtr number(PyInt_FromLong(extensions[i]->number())); + if (number == NULL) { + return NULL; + } + if (PyDict_SetItem(result.get(), number.get(), extension.get()) < 0) { + return NULL; + } + } + return result.release(); +} + +static PyGetSetDef Getters[] = { + {"_extensions_by_name", (getter)GetExtensionsByName, NULL}, + {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL}, + {NULL} +}; + +// Compute some class attributes on the fly: +// - All the _FIELD_NUMBER attributes, for all fields and nested extensions. +// Returns a new reference, or NULL with an exception set. +static PyObject* GetClassAttribute(CMessageClass *self, PyObject* name) { + char* attr; + Py_ssize_t attr_size; + static const char kSuffix[] = "_FIELD_NUMBER"; + if (PyString_AsStringAndSize(name, &attr, &attr_size) >= 0 && + HasSuffixString(StringPiece(attr, attr_size), kSuffix)) { + TProtoStringType field_name(attr, attr_size - sizeof(kSuffix) + 1); + LowerString(&field_name); + + // Try to find a field with the given name, without the suffix. + const FieldDescriptor* field = + self->message_descriptor->FindFieldByLowercaseName(field_name); + if (!field) { + // Search nested extensions as well. + field = + self->message_descriptor->FindExtensionByLowercaseName(field_name); + } + if (field) { + return PyInt_FromLong(field->number()); + } + } + PyErr_SetObject(PyExc_AttributeError, name); + return NULL; +} + +static PyObject* GetAttr(CMessageClass* self, PyObject* name) { + PyObject* result = CMessageClass_Type->tp_base->tp_getattro( + reinterpret_cast<PyObject*>(self), name); + if (result != NULL) { + return result; + } + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + return NULL; + } + + PyErr_Clear(); + return GetClassAttribute(self, name); +} + +} // namespace message_meta + +static PyTypeObject _CMessageClass_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME + ".MessageMeta", // tp_name + sizeof(CMessageClass), // tp_basicsize + 0, // tp_itemsize + message_meta::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + (getattrofunc)message_meta::GetAttr, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, // tp_flags + "The metaclass of ProtocolMessages", // tp_doc + message_meta::GcTraverse, // tp_traverse + message_meta::GcClear, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + 0, // tp_methods + 0, // tp_members + message_meta::Getters, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + message_meta::New, // tp_new +}; +PyTypeObject* CMessageClass_Type = &_CMessageClass_Type; + +static CMessageClass* CheckMessageClass(PyTypeObject* cls) { + if (!PyObject_TypeCheck(cls, CMessageClass_Type)) { + PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name); + return NULL; + } + return reinterpret_cast<CMessageClass*>(cls); +} + +static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) { + CMessageClass* type = CheckMessageClass(cls); + if (type == NULL) { + return NULL; + } + return type->message_descriptor; +} + +// Forward declarations +namespace cmessage { +int InternalReleaseFieldByDescriptor( + CMessage* self, + const FieldDescriptor* field_descriptor); +} // namespace cmessage + +// --------------------------------------------------------------------- + +PyObject* EncodeError_class; +PyObject* DecodeError_class; +PyObject* PickleError_class; + +// Format an error message for unexpected types. +// Always return with an exception set. +void FormatTypeError(PyObject* arg, const char* expected_types) { + // This function is often called with an exception set. + // Clear it to call PyObject_Repr() in good conditions. + PyErr_Clear(); + PyObject* repr = PyObject_Repr(arg); + if (repr) { + PyErr_Format(PyExc_TypeError, + "%.100s has type %.100s, but expected one of: %s", + PyString_AsString(repr), + Py_TYPE(arg)->tp_name, + expected_types); + Py_DECREF(repr); + } +} + +void OutOfRangeError(PyObject* arg) { + PyObject *s = PyObject_Str(arg); + if (s) { + PyErr_Format(PyExc_ValueError, + "Value out of range: %s", + PyString_AsString(s)); + Py_DECREF(s); + } +} + +template<class RangeType, class ValueType> +bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) { + if (PROTOBUF_PREDICT_FALSE(value == -1 && PyErr_Occurred())) { + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + // Replace it with the same ValueError as pure python protos instead of + // the default one. + PyErr_Clear(); + OutOfRangeError(arg); + } // Otherwise propagate existing error. + return false; + } + if (PROTOBUF_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value))) { + OutOfRangeError(arg); + return false; + } + return true; +} + +template <class T> +bool CheckAndGetInteger(PyObject* arg, T* value) { + // The fast path. +#if PY_MAJOR_VERSION < 3 + // For the typical case, offer a fast path. + if (PROTOBUF_PREDICT_TRUE(PyInt_Check(arg))) { + long int_result = PyInt_AsLong(arg); + if (PROTOBUF_PREDICT_TRUE(IsValidNumericCast<T>(int_result))) { + *value = static_cast<T>(int_result); + return true; + } else { + OutOfRangeError(arg); + return false; + } + } +#endif + // This effectively defines an integer as "an object that can be cast as + // an integer and can be used as an ordinal number". + // This definition includes everything that implements numbers.Integral + // and shouldn't cast the net too wide. + if (PROTOBUF_PREDICT_FALSE(!PyIndex_Check(arg))) { + FormatTypeError(arg, "int, long"); + return false; + } + + // Now we have an integral number so we can safely use PyLong_ functions. + // We need to treat the signed and unsigned cases differently in case arg is + // holding a value above the maximum for signed longs. + if (std::numeric_limits<T>::min() == 0) { + // Unsigned case. + unsigned PY_LONG_LONG ulong_result; + if (PyLong_Check(arg)) { + ulong_result = PyLong_AsUnsignedLongLong(arg); + } else { + // Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very + // picky about the exact type. + PyObject* casted = PyNumber_Long(arg); + if (PROTOBUF_PREDICT_FALSE(casted == nullptr)) { + // Propagate existing error. + return false; + } + ulong_result = PyLong_AsUnsignedLongLong(casted); + Py_DECREF(casted); + } + if (VerifyIntegerCastAndRange<T, unsigned PY_LONG_LONG>(arg, + ulong_result)) { + *value = static_cast<T>(ulong_result); + } else { + return false; + } + } else { + // Signed case. + PY_LONG_LONG long_result; + PyNumberMethods *nb; + if ((nb = arg->ob_type->tp_as_number) != NULL && nb->nb_int != NULL) { + // PyLong_AsLongLong requires it to be a long or to have an __int__() + // method. + long_result = PyLong_AsLongLong(arg); + } else { + // Valid subclasses of numbers.Integral should have a __long__() method + // so fall back to that. + PyObject* casted = PyNumber_Long(arg); + if (PROTOBUF_PREDICT_FALSE(casted == nullptr)) { + // Propagate existing error. + return false; + } + long_result = PyLong_AsLongLong(casted); + Py_DECREF(casted); + } + if (VerifyIntegerCastAndRange<T, PY_LONG_LONG>(arg, long_result)) { + *value = static_cast<T>(long_result); + } else { + return false; + } + } + + return true; +} + +// These are referenced by repeated_scalar_container, and must +// be explicitly instantiated. +template bool CheckAndGetInteger<int32>(PyObject*, int32*); +template bool CheckAndGetInteger<int64>(PyObject*, int64*); +template bool CheckAndGetInteger<uint32>(PyObject*, uint32*); +template bool CheckAndGetInteger<uint64>(PyObject*, uint64*); + +bool CheckAndGetDouble(PyObject* arg, double* value) { + *value = PyFloat_AsDouble(arg); + if (PROTOBUF_PREDICT_FALSE(*value == -1 && PyErr_Occurred())) { + FormatTypeError(arg, "int, long, float"); + return false; + } + return true; +} + +bool CheckAndGetFloat(PyObject* arg, float* value) { + double double_value; + if (!CheckAndGetDouble(arg, &double_value)) { + return false; + } + *value = io::SafeDoubleToFloat(double_value); + return true; +} + +bool CheckAndGetBool(PyObject* arg, bool* value) { + long long_value = PyInt_AsLong(arg); + if (long_value == -1 && PyErr_Occurred()) { + FormatTypeError(arg, "int, long, bool"); + return false; + } + *value = static_cast<bool>(long_value); + + return true; +} + +// Checks whether the given object (which must be "bytes" or "unicode") contains +// valid UTF-8. +bool IsValidUTF8(PyObject* obj) { + if (PyBytes_Check(obj)) { + PyObject* unicode = PyUnicode_FromEncodedObject(obj, "utf-8", NULL); + + // Clear the error indicator; we report our own error when desired. + PyErr_Clear(); + + if (unicode) { + Py_DECREF(unicode); + return true; + } else { + return false; + } + } else { + // Unicode object, known to be valid UTF-8. + return true; + } +} + +bool AllowInvalidUTF8(const FieldDescriptor* field) { return false; } + +PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) { + GOOGLE_DCHECK(descriptor->type() == FieldDescriptor::TYPE_STRING || + descriptor->type() == FieldDescriptor::TYPE_BYTES); + if (descriptor->type() == FieldDescriptor::TYPE_STRING) { + if (!PyBytes_Check(arg) && !PyUnicode_Check(arg)) { + FormatTypeError(arg, "bytes, unicode"); + return NULL; + } + + if (!IsValidUTF8(arg) && !AllowInvalidUTF8(descriptor)) { + PyObject* repr = PyObject_Repr(arg); + PyErr_Format(PyExc_ValueError, + "%s has type str, but isn't valid UTF-8 " + "encoding. Non-UTF-8 strings must be converted to " + "unicode objects before being added.", + PyString_AsString(repr)); + Py_DECREF(repr); + return NULL; + } + } else if (!PyBytes_Check(arg)) { + FormatTypeError(arg, "bytes"); + return NULL; + } + + PyObject* encoded_string = NULL; + if (descriptor->type() == FieldDescriptor::TYPE_STRING) { + if (PyBytes_Check(arg)) { + // The bytes were already validated as correctly encoded UTF-8 above. + encoded_string = arg; // Already encoded. + Py_INCREF(encoded_string); + } else { + encoded_string = PyUnicode_AsEncodedString(arg, "utf-8", NULL); + } + } else { + // In this case field type is "bytes". + encoded_string = arg; + Py_INCREF(encoded_string); + } + + return encoded_string; +} + +bool CheckAndSetString( + PyObject* arg, Message* message, + const FieldDescriptor* descriptor, + const Reflection* reflection, + bool append, + int index) { + ScopedPyObjectPtr encoded_string(CheckString(arg, descriptor)); + + if (encoded_string.get() == NULL) { + return false; + } + + char* value; + Py_ssize_t value_len; + if (PyBytes_AsStringAndSize(encoded_string.get(), &value, &value_len) < 0) { + return false; + } + + TProtoStringType value_string(value, value_len); + if (append) { + reflection->AddString(message, descriptor, std::move(value_string)); + } else if (index < 0) { + reflection->SetString(message, descriptor, std::move(value_string)); + } else { + reflection->SetRepeatedString(message, descriptor, index, + std::move(value_string)); + } + return true; +} + +PyObject* ToStringObject(const FieldDescriptor* descriptor, + const TProtoStringType& value) { +#if PY_MAJOR_VERSION >= 3 + if (descriptor->type() != FieldDescriptor::TYPE_STRING) { + return PyBytes_FromStringAndSize(value.c_str(), value.length()); + } + + PyObject* result = PyUnicode_DecodeUTF8(value.c_str(), value.length(), NULL); + // If the string can't be decoded in UTF-8, just return a string object that + // contains the raw bytes. This can't happen if the value was assigned using + // the members of the Python message object, but can happen if the values were + // parsed from the wire (binary). + if (result == NULL) { + PyErr_Clear(); + result = PyBytes_FromStringAndSize(value.c_str(), value.length()); + } + return result; +#else + return PyBytes_FromStringAndSize(value.c_str(), value.length()); +#endif +} + +bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, + const Message* message) { + if (message->GetDescriptor() == field_descriptor->containing_type()) { + return true; + } + PyErr_Format(PyExc_KeyError, "Field '%s' does not belong to message '%s'", + field_descriptor->full_name().c_str(), + message->GetDescriptor()->full_name().c_str()); + return false; +} + +namespace cmessage { + +PyMessageFactory* GetFactoryForMessage(CMessage* message) { + GOOGLE_DCHECK(PyObject_TypeCheck(message, CMessage_Type)); + return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_message_factory; +} + +static int MaybeReleaseOverlappingOneofField( + CMessage* cmessage, + const FieldDescriptor* field) { +#ifdef GOOGLE_PROTOBUF_HAS_ONEOF + Message* message = cmessage->message; + const Reflection* reflection = message->GetReflection(); + if (!field->containing_oneof() || + !reflection->HasOneof(*message, field->containing_oneof()) || + reflection->HasField(*message, field)) { + // No other field in this oneof, no need to release. + return 0; + } + + const OneofDescriptor* oneof = field->containing_oneof(); + const FieldDescriptor* existing_field = + reflection->GetOneofFieldDescriptor(*message, oneof); + if (existing_field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { + // Non-message fields don't need to be released. + return 0; + } + if (InternalReleaseFieldByDescriptor(cmessage, existing_field) < 0) { + return -1; + } +#endif + return 0; +} + +// After a Merge, visit every sub-message that was read-only, and +// eventually update their pointer if the Merge operation modified them. +int FixupMessageAfterMerge(CMessage* self) { + if (!self->composite_fields) { + return 0; + } + for (const auto& item : *self->composite_fields) { + const FieldDescriptor* descriptor = item.first; + if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE && + !descriptor->is_repeated()) { + CMessage* cmsg = reinterpret_cast<CMessage*>(item.second); + if (cmsg->read_only == false) { + return 0; + } + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + if (reflection->HasField(*message, descriptor)) { + // Message used to be read_only, but is no longer. Get the new pointer + // and record it. + Message* mutable_message = + reflection->MutableMessage(message, descriptor, nullptr); + cmsg->message = mutable_message; + cmsg->read_only = false; + if (FixupMessageAfterMerge(cmsg) < 0) { + return -1; + } + } + } + } + + return 0; +} + +// --------------------------------------------------------------------- +// Making a message writable + +int AssureWritable(CMessage* self) { + if (self == NULL || !self->read_only) { + return 0; + } + + // Toplevel messages are always mutable. + GOOGLE_DCHECK(self->parent); + + if (AssureWritable(self->parent) == -1) { + return -1; + } + // If this message is part of a oneof, there might be a field to release in + // the parent. + if (MaybeReleaseOverlappingOneofField(self->parent, + self->parent_field_descriptor) < 0) { + return -1; + } + + // Make self->message writable. + Message* parent_message = self->parent->message; + const Reflection* reflection = parent_message->GetReflection(); + Message* mutable_message = reflection->MutableMessage( + parent_message, self->parent_field_descriptor, + GetFactoryForMessage(self->parent)->message_factory); + if (mutable_message == NULL) { + return -1; + } + self->message = mutable_message; + self->read_only = false; + + return 0; +} + +// --- Globals: + +// Retrieve a C++ FieldDescriptor for an extension handle. +const FieldDescriptor* GetExtensionDescriptor(PyObject* extension) { + ScopedPyObjectPtr cdescriptor; + if (!PyObject_TypeCheck(extension, &PyFieldDescriptor_Type)) { + // Most callers consider extensions as a plain dictionary. We should + // allow input which is not a field descriptor, and simply pretend it does + // not exist. + PyErr_SetObject(PyExc_KeyError, extension); + return NULL; + } + return PyFieldDescriptor_AsDescriptor(extension); +} + +// If value is a string, convert it into an enum value based on the labels in +// descriptor, otherwise simply return value. Always returns a new reference. +static PyObject* GetIntegerEnumValue(const FieldDescriptor& descriptor, + PyObject* value) { + if (PyString_Check(value) || PyUnicode_Check(value)) { + const EnumDescriptor* enum_descriptor = descriptor.enum_type(); + if (enum_descriptor == NULL) { + PyErr_SetString(PyExc_TypeError, "not an enum field"); + return NULL; + } + char* enum_label; + Py_ssize_t size; + if (PyString_AsStringAndSize(value, &enum_label, &size) < 0) { + return NULL; + } + const EnumValueDescriptor* enum_value_descriptor = + enum_descriptor->FindValueByName(StringParam(enum_label, size)); + if (enum_value_descriptor == NULL) { + PyErr_Format(PyExc_ValueError, "unknown enum label \"%s\"", enum_label); + return NULL; + } + return PyInt_FromLong(enum_value_descriptor->number()); + } + Py_INCREF(value); + return value; +} + +// Delete a slice from a repeated field. +// The only way to remove items in C++ protos is to delete the last one, +// so we swap items to move the deleted ones at the end, and then strip the +// sequence. +int DeleteRepeatedField( + CMessage* self, + const FieldDescriptor* field_descriptor, + PyObject* slice) { + Py_ssize_t length, from, to, step, slice_length; + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + int min, max; + length = reflection->FieldSize(*message, field_descriptor); + + if (PySlice_Check(slice)) { + from = to = step = slice_length = 0; +#if PY_MAJOR_VERSION < 3 + PySlice_GetIndicesEx( + reinterpret_cast<PySliceObject*>(slice), + length, &from, &to, &step, &slice_length); +#else + PySlice_GetIndicesEx( + slice, + length, &from, &to, &step, &slice_length); +#endif + if (from < to) { + min = from; + max = to - 1; + } else { + min = to + 1; + max = from; + } + } else { + from = to = PyLong_AsLong(slice); + if (from == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, "list indices must be integers"); + return -1; + } + + if (from < 0) { + from = to = length + from; + } + step = 1; + min = max = from; + + // Range check. + if (from < 0 || from >= length) { + PyErr_Format(PyExc_IndexError, "list assignment index out of range"); + return -1; + } + } + + Py_ssize_t i = from; + std::vector<bool> to_delete(length, false); + while (i >= min && i <= max) { + to_delete[i] = true; + i += step; + } + + // Swap elements so that items to delete are at the end. + to = 0; + for (i = 0; i < length; ++i) { + if (!to_delete[i]) { + if (i != to) { + reflection->SwapElements(message, field_descriptor, i, to); + } + ++to; + } + } + + // Remove items, starting from the end. + for (; length > to; length--) { + if (field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { + reflection->RemoveLast(message, field_descriptor); + continue; + } + // It seems that RemoveLast() is less efficient for sub-messages, and + // the memory is not completely released. Prefer ReleaseLast(). + Message* sub_message = reflection->ReleaseLast(message, field_descriptor); + // If there is a live weak reference to an item being removed, we "Release" + // it, and it takes ownership of the message. + if (CMessage* released = self->MaybeReleaseSubMessage(sub_message)) { + released->message = sub_message; + } else { + // sub_message was not transferred, delete it. + delete sub_message; + } + } + + return 0; +} + +// Initializes fields of a message. Used in constructors. +int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { + if (args != NULL && PyTuple_Size(args) != 0) { + PyErr_SetString(PyExc_TypeError, "No positional arguments allowed"); + return -1; + } + + if (kwargs == NULL) { + return 0; + } + + Py_ssize_t pos = 0; + PyObject* name; + PyObject* value; + while (PyDict_Next(kwargs, &pos, &name, &value)) { + if (!(PyString_Check(name) || PyUnicode_Check(name))) { + PyErr_SetString(PyExc_ValueError, "Field name must be a string"); + return -1; + } + ScopedPyObjectPtr property( + PyObject_GetAttr(reinterpret_cast<PyObject*>(Py_TYPE(self)), name)); + if (property == NULL || + !PyObject_TypeCheck(property.get(), CFieldProperty_Type)) { + PyErr_Format(PyExc_ValueError, "Protocol message %s has no \"%s\" field.", + self->message->GetDescriptor()->name().c_str(), + PyString_AsString(name)); + return -1; + } + const FieldDescriptor* descriptor = + reinterpret_cast<PyMessageFieldProperty*>(property.get()) + ->field_descriptor; + if (value == Py_None) { + // field=None is the same as no field at all. + continue; + } + if (descriptor->is_map()) { + ScopedPyObjectPtr map(GetFieldValue(self, descriptor)); + const FieldDescriptor* value_descriptor = + descriptor->message_type()->FindFieldByName("value"); + if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + ScopedPyObjectPtr iter(PyObject_GetIter(value)); + if (iter == NULL) { + PyErr_Format(PyExc_TypeError, "Argument %s is not iterable", PyString_AsString(name)); + return -1; + } + ScopedPyObjectPtr next; + while ((next.reset(PyIter_Next(iter.get()))) != NULL) { + ScopedPyObjectPtr source_value(PyObject_GetItem(value, next.get())); + ScopedPyObjectPtr dest_value(PyObject_GetItem(map.get(), next.get())); + if (source_value.get() == NULL || dest_value.get() == NULL) { + return -1; + } + ScopedPyObjectPtr ok(PyObject_CallMethod( + dest_value.get(), "MergeFrom", "O", source_value.get())); + if (ok.get() == NULL) { + return -1; + } + } + } else { + ScopedPyObjectPtr function_return; + function_return.reset( + PyObject_CallMethod(map.get(), "update", "O", value)); + if (function_return.get() == NULL) { + return -1; + } + } + } else if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + ScopedPyObjectPtr container(GetFieldValue(self, descriptor)); + if (container == NULL) { + return -1; + } + if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + RepeatedCompositeContainer* rc_container = + reinterpret_cast<RepeatedCompositeContainer*>(container.get()); + ScopedPyObjectPtr iter(PyObject_GetIter(value)); + if (iter == NULL) { + PyErr_SetString(PyExc_TypeError, "Value must be iterable"); + return -1; + } + ScopedPyObjectPtr next; + while ((next.reset(PyIter_Next(iter.get()))) != NULL) { + PyObject* kwargs = (PyDict_Check(next.get()) ? next.get() : NULL); + ScopedPyObjectPtr new_msg( + repeated_composite_container::Add(rc_container, NULL, kwargs)); + if (new_msg == NULL) { + return -1; + } + if (kwargs == NULL) { + // next was not a dict, it's a message we need to merge + ScopedPyObjectPtr merged(MergeFrom( + reinterpret_cast<CMessage*>(new_msg.get()), next.get())); + if (merged.get() == NULL) { + return -1; + } + } + } + if (PyErr_Occurred()) { + // Check to see how PyIter_Next() exited. + return -1; + } + } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { + RepeatedScalarContainer* rs_container = + reinterpret_cast<RepeatedScalarContainer*>(container.get()); + ScopedPyObjectPtr iter(PyObject_GetIter(value)); + if (iter == NULL) { + PyErr_SetString(PyExc_TypeError, "Value must be iterable"); + return -1; + } + ScopedPyObjectPtr next; + while ((next.reset(PyIter_Next(iter.get()))) != NULL) { + ScopedPyObjectPtr enum_value( + GetIntegerEnumValue(*descriptor, next.get())); + if (enum_value == NULL) { + return -1; + } + ScopedPyObjectPtr new_msg(repeated_scalar_container::Append( + rs_container, enum_value.get())); + if (new_msg == NULL) { + return -1; + } + } + if (PyErr_Occurred()) { + // Check to see how PyIter_Next() exited. + return -1; + } + } else { + if (ScopedPyObjectPtr(repeated_scalar_container::Extend( + reinterpret_cast<RepeatedScalarContainer*>(container.get()), + value)) == + NULL) { + return -1; + } + } + } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + ScopedPyObjectPtr message(GetFieldValue(self, descriptor)); + if (message == NULL) { + return -1; + } + CMessage* cmessage = reinterpret_cast<CMessage*>(message.get()); + if (PyDict_Check(value)) { + // Make the message exist even if the dict is empty. + AssureWritable(cmessage); + if (InitAttributes(cmessage, NULL, value) < 0) { + return -1; + } + } else { + ScopedPyObjectPtr merged(MergeFrom(cmessage, value)); + if (merged == NULL) { + return -1; + } + } + } else { + ScopedPyObjectPtr new_val; + if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { + new_val.reset(GetIntegerEnumValue(*descriptor, value)); + if (new_val == NULL) { + return -1; + } + value = new_val.get(); + } + if (SetFieldValue(self, descriptor, value) < 0) { + return -1; + } + } + } + return 0; +} + +// Allocates an incomplete Python Message: the caller must fill self->message +// and eventually self->parent. +CMessage* NewEmptyMessage(CMessageClass* type) { + CMessage* self = reinterpret_cast<CMessage*>( + PyType_GenericAlloc(&type->super.ht_type, 0)); + if (self == NULL) { + return NULL; + } + + self->message = NULL; + self->parent = NULL; + self->parent_field_descriptor = NULL; + self->read_only = false; + + self->composite_fields = NULL; + self->child_submessages = NULL; + + self->unknown_field_set = NULL; + + return self; +} + +// The __new__ method of Message classes. +// Creates a new C++ message and takes ownership. +static CMessage* NewCMessage(CMessageClass* type) { + // Retrieve the message descriptor and the default instance (=prototype). + const Descriptor* message_descriptor = type->message_descriptor; + if (message_descriptor == nullptr) { + // This would be very unexpected since the CMessageClass has already + // been checked. + PyErr_Format(PyExc_TypeError, + "CMessageClass object '%s' has no descriptor.", + Py_TYPE(type)->tp_name); + return nullptr; + } + const Message* prototype = + type->py_message_factory->message_factory->GetPrototype( + message_descriptor); + if (prototype == nullptr) { + PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str()); + return nullptr; + } + + CMessage* self = NewEmptyMessage(type); + if (self == nullptr) { + return nullptr; + } + self->message = prototype->New(nullptr); // Ensures no arena is used. + self->parent = nullptr; // This message owns its data. + return self; +} + +static PyObject* New(PyTypeObject* cls, PyObject* unused_args, + PyObject* unused_kwargs) { + CMessageClass* type = CheckMessageClass(cls); + if (type == nullptr) { + return nullptr; + } + return reinterpret_cast<PyObject*>(NewCMessage(type)); +} + +// The __init__ method of Message classes. +// It initializes fields from keywords passed to the constructor. +static int Init(CMessage* self, PyObject* args, PyObject* kwargs) { + return InitAttributes(self, args, kwargs); +} + +// --------------------------------------------------------------------- +// Deallocating a CMessage + +static void Dealloc(CMessage* self) { + if (self->weakreflist) { + PyObject_ClearWeakRefs(reinterpret_cast<PyObject*>(self)); + } + // At this point all dependent objects have been removed. + GOOGLE_DCHECK(!self->child_submessages || self->child_submessages->empty()); + GOOGLE_DCHECK(!self->composite_fields || self->composite_fields->empty()); + delete self->child_submessages; + delete self->composite_fields; + if (self->unknown_field_set) { + unknown_fields::Clear( + reinterpret_cast<PyUnknownFields*>(self->unknown_field_set)); + } + + CMessage* parent = self->parent; + if (!parent) { + // No parent, we own the message. + delete self->message; + } else if (parent->AsPyObject() == Py_None) { + // Message owned externally: Nothing to dealloc + Py_CLEAR(self->parent); + } else { + // Clear this message from its parent's map. + if (self->parent_field_descriptor->is_repeated()) { + if (parent->child_submessages) + parent->child_submessages->erase(self->message); + } else { + if (parent->composite_fields) + parent->composite_fields->erase(self->parent_field_descriptor); + } + Py_CLEAR(self->parent); + } + Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); +} + +// --------------------------------------------------------------------- + + +PyObject* IsInitialized(CMessage* self, PyObject* args) { + PyObject* errors = NULL; + if (!PyArg_ParseTuple(args, "|O", &errors)) { + return NULL; + } + if (self->message->IsInitialized()) { + Py_RETURN_TRUE; + } + if (errors != NULL) { + ScopedPyObjectPtr initialization_errors( + FindInitializationErrors(self)); + if (initialization_errors == NULL) { + return NULL; + } + ScopedPyObjectPtr extend_name(PyString_FromString("extend")); + if (extend_name == NULL) { + return NULL; + } + ScopedPyObjectPtr result(PyObject_CallMethodObjArgs( + errors, + extend_name.get(), + initialization_errors.get(), + NULL)); + if (result == NULL) { + return NULL; + } + } + Py_RETURN_FALSE; +} + +int HasFieldByDescriptor(CMessage* self, + const FieldDescriptor* field_descriptor) { + Message* message = self->message; + if (!CheckFieldBelongsToMessage(field_descriptor, message)) { + return -1; + } + if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + PyErr_SetString(PyExc_KeyError, + "Field is repeated. A singular method is required."); + return -1; + } + return message->GetReflection()->HasField(*message, field_descriptor); +} + +const FieldDescriptor* FindFieldWithOneofs(const Message* message, + ConstStringParam field_name, + bool* in_oneof) { + *in_oneof = false; + const Descriptor* descriptor = message->GetDescriptor(); + const FieldDescriptor* field_descriptor = + descriptor->FindFieldByName(field_name); + if (field_descriptor != NULL) { + return field_descriptor; + } + const OneofDescriptor* oneof_desc = + descriptor->FindOneofByName(field_name); + if (oneof_desc != NULL) { + *in_oneof = true; + return message->GetReflection()->GetOneofFieldDescriptor(*message, + oneof_desc); + } + return NULL; +} + +bool CheckHasPresence(const FieldDescriptor* field_descriptor, bool in_oneof) { + auto message_name = field_descriptor->containing_type()->name(); + if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + PyErr_Format(PyExc_ValueError, + "Protocol message %s has no singular \"%s\" field.", + message_name.c_str(), field_descriptor->name().c_str()); + return false; + } + + if (!field_descriptor->has_presence()) { + PyErr_Format(PyExc_ValueError, + "Can't test non-optional, non-submessage field \"%s.%s\" for " + "presence in proto3.", + message_name.c_str(), field_descriptor->name().c_str()); + return false; + } + + return true; +} + +PyObject* HasField(CMessage* self, PyObject* arg) { + char* field_name; + Py_ssize_t size; +#if PY_MAJOR_VERSION < 3 + if (PyString_AsStringAndSize(arg, &field_name, &size) < 0) { + return NULL; + } +#else + field_name = const_cast<char*>(PyUnicode_AsUTF8AndSize(arg, &size)); + if (!field_name) { + return NULL; + } +#endif + + Message* message = self->message; + bool is_in_oneof; + const FieldDescriptor* field_descriptor = + FindFieldWithOneofs(message, StringParam(field_name, size), &is_in_oneof); + if (field_descriptor == NULL) { + if (!is_in_oneof) { + PyErr_Format(PyExc_ValueError, "Protocol message %s has no field %s.", + message->GetDescriptor()->name().c_str(), field_name); + return NULL; + } else { + Py_RETURN_FALSE; + } + } + + if (!CheckHasPresence(field_descriptor, is_in_oneof)) { + return NULL; + } + + if (message->GetReflection()->HasField(*message, field_descriptor)) { + Py_RETURN_TRUE; + } + + Py_RETURN_FALSE; +} + +PyObject* ClearExtension(CMessage* self, PyObject* extension) { + const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); + if (descriptor == NULL) { + return NULL; + } + if (ClearFieldByDescriptor(self, descriptor) < 0) { + return nullptr; + } + Py_RETURN_NONE; +} + +PyObject* HasExtension(CMessage* self, PyObject* extension) { + const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); + if (descriptor == NULL) { + return NULL; + } + int has_field = HasFieldByDescriptor(self, descriptor); + if (has_field < 0) { + return nullptr; + } else { + return PyBool_FromLong(has_field); + } +} + +// --------------------------------------------------------------------- +// Releasing messages +// +// The Python API's ClearField() and Clear() methods behave +// differently than their C++ counterparts. While the C++ versions +// clears the children, the Python versions detaches the children, +// without touching their content. This impedance mismatch causes +// some complexity in the implementation, which is captured in this +// section. +// +// When one or multiple fields are cleared we need to: +// +// * Gather all child objects that need to be detached from the message. +// In composite_fields and child_submessages. +// +// * Create a new Python message of the same kind. Use SwapFields() to move +// data from the original message. +// +// * Change the parent of all child objects: update their strong reference +// to their parent, and move their presence in composite_fields and +// child_submessages. + +// --------------------------------------------------------------------- +// Release a composite child of a CMessage + +static int InternalReparentFields( + CMessage* self, const std::vector<CMessage*>& messages_to_release, + const std::vector<ContainerBase*>& containers_to_release) { + if (messages_to_release.empty() && containers_to_release.empty()) { + return 0; + } + + // Move all the passed sub_messages to another message. + CMessage* new_message = cmessage::NewEmptyMessage(self->GetMessageClass()); + if (new_message == nullptr) { + return -1; + } + new_message->message = self->message->New(nullptr); + ScopedPyObjectPtr holder(reinterpret_cast<PyObject*>(new_message)); + new_message->child_submessages = new CMessage::SubMessagesMap(); + new_message->composite_fields = new CMessage::CompositeFieldsMap(); + std::set<const FieldDescriptor*> fields_to_swap; + + // In case this the removed fields are the last reference to a message, keep + // a reference. + Py_INCREF(self); + + for (const auto& to_release : messages_to_release) { + fields_to_swap.insert(to_release->parent_field_descriptor); + // Reparent + Py_INCREF(new_message); + Py_DECREF(to_release->parent); + to_release->parent = new_message; + self->child_submessages->erase(to_release->message); + new_message->child_submessages->emplace(to_release->message, to_release); + } + + for (const auto& to_release : containers_to_release) { + fields_to_swap.insert(to_release->parent_field_descriptor); + Py_INCREF(new_message); + Py_DECREF(to_release->parent); + to_release->parent = new_message; + self->composite_fields->erase(to_release->parent_field_descriptor); + new_message->composite_fields->emplace(to_release->parent_field_descriptor, + to_release); + } + + if (self->message->GetArena() == new_message->message->GetArena()) { + MessageReflectionFriend::UnsafeShallowSwapFields( + self->message, new_message->message, + std::vector<const FieldDescriptor*>(fields_to_swap.begin(), + fields_to_swap.end())); + } else { + self->message->GetReflection()->SwapFields( + self->message, new_message->message, + std::vector<const FieldDescriptor*>(fields_to_swap.begin(), + fields_to_swap.end())); + } + + // This might delete the Python message completely if all children were moved. + Py_DECREF(self); + + return 0; +} + +int InternalReleaseFieldByDescriptor( + CMessage* self, + const FieldDescriptor* field_descriptor) { + if (!field_descriptor->is_repeated() && + field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { + // Single scalars are not in any cache. + return 0; + } + std::vector<CMessage*> messages_to_release; + std::vector<ContainerBase*> containers_to_release; + if (self->child_submessages && field_descriptor->is_repeated() && + field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + for (const auto& child_item : *self->child_submessages) { + if (child_item.second->parent_field_descriptor == field_descriptor) { + messages_to_release.push_back(child_item.second); + } + } + } + if (self->composite_fields) { + CMessage::CompositeFieldsMap::iterator it = + self->composite_fields->find(field_descriptor); + if (it != self->composite_fields->end()) { + containers_to_release.push_back(it->second); + } + } + + return InternalReparentFields(self, messages_to_release, + containers_to_release); +} + +int ClearFieldByDescriptor(CMessage* self, + const FieldDescriptor* field_descriptor) { + if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) { + return -1; + } + if (InternalReleaseFieldByDescriptor(self, field_descriptor) < 0) { + return -1; + } + AssureWritable(self); + Message* message = self->message; + message->GetReflection()->ClearField(message, field_descriptor); + return 0; +} + +PyObject* ClearField(CMessage* self, PyObject* arg) { + char* field_name; + Py_ssize_t field_size; + if (PyString_AsStringAndSize(arg, &field_name, &field_size) < 0) { + return NULL; + } + AssureWritable(self); + bool is_in_oneof; + const FieldDescriptor* field_descriptor = FindFieldWithOneofs( + self->message, StringParam(field_name, field_size), &is_in_oneof); + if (field_descriptor == NULL) { + if (is_in_oneof) { + // We gave the name of a oneof, and none of its fields are set. + Py_RETURN_NONE; + } else { + PyErr_Format(PyExc_ValueError, + "Protocol message has no \"%s\" field.", field_name); + return NULL; + } + } + + if (ClearFieldByDescriptor(self, field_descriptor) < 0) { + return nullptr; + } + Py_RETURN_NONE; +} + +PyObject* Clear(CMessage* self) { + AssureWritable(self); + // Detach all current fields of this message + std::vector<CMessage*> messages_to_release; + std::vector<ContainerBase*> containers_to_release; + if (self->child_submessages) { + for (const auto& item : *self->child_submessages) { + messages_to_release.push_back(item.second); + } + } + if (self->composite_fields) { + for (const auto& item : *self->composite_fields) { + containers_to_release.push_back(item.second); + } + } + if (InternalReparentFields(self, messages_to_release, containers_to_release) < + 0) { + return NULL; + } + if (self->unknown_field_set) { + unknown_fields::Clear( + reinterpret_cast<PyUnknownFields*>(self->unknown_field_set)); + self->unknown_field_set = nullptr; + } + self->message->Clear(); + Py_RETURN_NONE; +} + +// --------------------------------------------------------------------- + +static TProtoStringType GetMessageName(CMessage* self) { + if (self->parent_field_descriptor != NULL) { + return self->parent_field_descriptor->full_name(); + } else { + return self->message->GetDescriptor()->full_name(); + } +} + +static PyObject* InternalSerializeToString( + CMessage* self, PyObject* args, PyObject* kwargs, + bool require_initialized) { + // Parse the "deterministic" kwarg; defaults to False. + static const char* kwlist[] = {"deterministic", 0}; + PyObject* deterministic_obj = Py_None; + if (!PyArg_ParseTupleAndKeywords( + args, kwargs, "|O", const_cast<char**>(kwlist), &deterministic_obj)) { + return NULL; + } + // Preemptively convert to a bool first, so we don't need to back out of + // allocating memory if this raises an exception. + // NOTE: This is unused later if deterministic == Py_None, but that's fine. + int deterministic = PyObject_IsTrue(deterministic_obj); + if (deterministic < 0) { + return NULL; + } + + if (require_initialized && !self->message->IsInitialized()) { + ScopedPyObjectPtr errors(FindInitializationErrors(self)); + if (errors == NULL) { + return NULL; + } + ScopedPyObjectPtr comma(PyString_FromString(",")); + if (comma == NULL) { + return NULL; + } + ScopedPyObjectPtr joined( + PyObject_CallMethod(comma.get(), "join", "O", errors.get())); + if (joined == NULL) { + return NULL; + } + + // TODO(haberman): this is a (hopefully temporary) hack. The unit testing + // infrastructure reloads all pure-Python modules for every test, but not + // C++ modules (because that's generally impossible: + // http://bugs.python.org/issue1144263). But if we cache EncodeError, we'll + // return the EncodeError from a previous load of the module, which won't + // match a user's attempt to catch EncodeError. So we have to look it up + // again every time. + ScopedPyObjectPtr message_module(PyImport_ImportModule( + "google.protobuf.message")); + if (message_module.get() == NULL) { + return NULL; + } + + ScopedPyObjectPtr encode_error( + PyObject_GetAttrString(message_module.get(), "EncodeError")); + if (encode_error.get() == NULL) { + return NULL; + } + PyErr_Format(encode_error.get(), + "Message %s is missing required fields: %s", + GetMessageName(self).c_str(), PyString_AsString(joined.get())); + return NULL; + } + + // Ok, arguments parsed and errors checked, now encode to a string + const size_t size = self->message->ByteSizeLong(); + if (size == 0) { + return PyBytes_FromString(""); + } + + if (size > INT_MAX) { + PyErr_Format(PyExc_ValueError, + "Message %s exceeds maximum protobuf " + "size of 2GB: %zu", + GetMessageName(self).c_str(), size); + return nullptr; + } + + PyObject* result = PyBytes_FromStringAndSize(NULL, size); + if (result == NULL) { + return NULL; + } + io::ArrayOutputStream out(PyBytes_AS_STRING(result), size); + io::CodedOutputStream coded_out(&out); + if (deterministic_obj != Py_None) { + coded_out.SetSerializationDeterministic(deterministic); + } + self->message->SerializeWithCachedSizes(&coded_out); + GOOGLE_CHECK(!coded_out.HadError()); + return result; +} + +static PyObject* SerializeToString( + CMessage* self, PyObject* args, PyObject* kwargs) { + return InternalSerializeToString(self, args, kwargs, + /*require_initialized=*/true); +} + +static PyObject* SerializePartialToString( + CMessage* self, PyObject* args, PyObject* kwargs) { + return InternalSerializeToString(self, args, kwargs, + /*require_initialized=*/false); +} + +// Formats proto fields for ascii dumps using python formatting functions where +// appropriate. +class PythonFieldValuePrinter : public TextFormat::FastFieldValuePrinter { + public: + // Python has some differences from C++ when printing floating point numbers. + // + // 1) Trailing .0 is always printed. + // 2) (Python2) Output is rounded to 12 digits. + // 3) (Python3) The full precision of the double is preserved (and Python uses + // David M. Gay's dtoa(), when the C++ code uses SimpleDtoa. There are some + // differences, but they rarely happen) + // + // We override floating point printing with the C-API function for printing + // Python floats to ensure consistency. + void PrintFloat(float val, + TextFormat::BaseTextGenerator* generator) const override { + PrintDouble(val, generator); + } + void PrintDouble(double val, + TextFormat::BaseTextGenerator* generator) const override { + // This implementation is not highly optimized (it allocates two temporary + // Python objects) but it is simple and portable. If this is shown to be a + // performance bottleneck, we can optimize it, but the results will likely + // be more complicated to accommodate the differing behavior of double + // formatting between Python 2 and Python 3. + // + // (Though a valid question is: do we really want to make out output + // dependent on the Python version?) + ScopedPyObjectPtr py_value(PyFloat_FromDouble(val)); + if (!py_value.get()) { + return; + } + + ScopedPyObjectPtr py_str(PyObject_Str(py_value.get())); + if (!py_str.get()) { + return; + } + + generator->PrintString(PyString_AsString(py_str.get())); + } +}; + +static PyObject* ToStr(CMessage* self) { + TextFormat::Printer printer; + // Passes ownership + printer.SetDefaultFieldValuePrinter(new PythonFieldValuePrinter()); + printer.SetHideUnknownFields(true); + TProtoStringType output; + if (!printer.PrintToString(*self->message, &output)) { + PyErr_SetString(PyExc_ValueError, "Unable to convert message to str"); + return NULL; + } + return PyString_FromString(output.c_str()); +} + +PyObject* MergeFrom(CMessage* self, PyObject* arg) { + CMessage* other_message; + if (!PyObject_TypeCheck(arg, CMessage_Type)) { + PyErr_Format(PyExc_TypeError, + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s.", + self->message->GetDescriptor()->full_name().c_str(), + Py_TYPE(arg)->tp_name); + return NULL; + } + + other_message = reinterpret_cast<CMessage*>(arg); + if (other_message->message->GetDescriptor() != + self->message->GetDescriptor()) { + PyErr_Format(PyExc_TypeError, + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s.", + self->message->GetDescriptor()->full_name().c_str(), + other_message->message->GetDescriptor()->full_name().c_str()); + return NULL; + } + AssureWritable(self); + + self->message->MergeFrom(*other_message->message); + // Child message might be lazily created before MergeFrom. Make sure they + // are mutable at this point if child messages are really created. + if (FixupMessageAfterMerge(self) < 0) { + return NULL; + } + + Py_RETURN_NONE; +} + +static PyObject* CopyFrom(CMessage* self, PyObject* arg) { + CMessage* other_message; + if (!PyObject_TypeCheck(arg, CMessage_Type)) { + PyErr_Format(PyExc_TypeError, + "Parameter to CopyFrom() must be instance of same class: " + "expected %s got %s.", + self->message->GetDescriptor()->full_name().c_str(), + Py_TYPE(arg)->tp_name); + return NULL; + } + + other_message = reinterpret_cast<CMessage*>(arg); + + if (self == other_message) { + Py_RETURN_NONE; + } + + if (other_message->message->GetDescriptor() != + self->message->GetDescriptor()) { + PyErr_Format(PyExc_TypeError, + "Parameter to CopyFrom() must be instance of same class: " + "expected %s got %s.", + self->message->GetDescriptor()->full_name().c_str(), + other_message->message->GetDescriptor()->full_name().c_str()); + return NULL; + } + + AssureWritable(self); + + // CopyFrom on the message will not clean up self->composite_fields, + // which can leave us in an inconsistent state, so clear it out here. + (void)ScopedPyObjectPtr(Clear(self)); + + self->message->CopyFrom(*other_message->message); + + Py_RETURN_NONE; +} + +// Protobuf has a 64MB limit built in, this variable will override this. Please +// do not enable this unless you fully understand the implications: protobufs +// must all be kept in memory at the same time, so if they grow too big you may +// get OOM errors. The protobuf APIs do not provide any tools for processing +// protobufs in chunks. If you have protos this big you should break them up if +// it is at all convenient to do so. +#ifdef PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS +static bool allow_oversize_protos = true; +#else +static bool allow_oversize_protos = false; +#endif + +// Provide a method in the module to set allow_oversize_protos to a boolean +// value. This method returns the newly value of allow_oversize_protos. +PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) { + if (!arg || !PyBool_Check(arg)) { + PyErr_SetString(PyExc_TypeError, + "Argument to SetAllowOversizeProtos must be boolean"); + return NULL; + } + allow_oversize_protos = PyObject_IsTrue(arg); + if (allow_oversize_protos) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + +static PyObject* MergeFromString(CMessage* self, PyObject* arg) { + Py_buffer data; + if (PyObject_GetBuffer(arg, &data, PyBUF_SIMPLE) < 0) { + return NULL; + } + + AssureWritable(self); + + PyMessageFactory* factory = GetFactoryForMessage(self); + int depth = allow_oversize_protos + ? INT_MAX + : io::CodedInputStream::GetDefaultRecursionLimit(); + const char* ptr; + internal::ParseContext ctx( + depth, false, &ptr, + StringPiece(static_cast<const char*>(data.buf), data.len)); + PyBuffer_Release(&data); + ctx.data().pool = factory->pool->pool; + ctx.data().factory = factory->message_factory; + + ptr = self->message->_InternalParse(ptr, &ctx); + + // Child message might be lazily created before MergeFrom. Make sure they + // are mutable at this point if child messages are really created. + if (FixupMessageAfterMerge(self) < 0) { + return NULL; + } + + // Python makes distinction in error message, between a general parse failure + // and in-correct ending on a terminating tag. Hence we need to be a bit more + // explicit in our correctness checks. + if (ptr == nullptr || ctx.BytesUntilLimit(ptr) < 0) { + // Parse error or the parser overshoot the limit. + PyErr_Format(DecodeError_class, "Error parsing message"); + return NULL; + } + // ctx has an explicit limit set (length of string_view), so we have to + // check we ended at that limit. + if (!ctx.EndedAtLimit()) { + PyErr_Format(DecodeError_class, "Unexpected end-group tag: Not all data was converted"); + return nullptr; + } + return PyInt_FromLong(data.len); +} + +static PyObject* ParseFromString(CMessage* self, PyObject* arg) { + if (ScopedPyObjectPtr(Clear(self)) == NULL) { + return NULL; + } + return MergeFromString(self, arg); +} + +static PyObject* ByteSize(CMessage* self, PyObject* args) { + return PyLong_FromLong(self->message->ByteSizeLong()); +} + +PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) { + const FieldDescriptor* descriptor = + GetExtensionDescriptor(extension_handle); + if (descriptor == NULL) { + return NULL; + } + if (!PyObject_TypeCheck(cls, CMessageClass_Type)) { + PyErr_Format(PyExc_TypeError, "Expected a message class, got %s", + cls->ob_type->tp_name); + return NULL; + } + CMessageClass *message_class = reinterpret_cast<CMessageClass*>(cls); + if (message_class == NULL) { + return NULL; + } + // If the extension was already registered, check that it is the same. + const FieldDescriptor* existing_extension = + message_class->py_message_factory->pool->pool->FindExtensionByNumber( + descriptor->containing_type(), descriptor->number()); + if (existing_extension != NULL && existing_extension != descriptor) { + PyErr_SetString(PyExc_ValueError, "Double registration of Extensions"); + return NULL; + } + Py_RETURN_NONE; +} + +static PyObject* SetInParent(CMessage* self, PyObject* args) { + AssureWritable(self); + Py_RETURN_NONE; +} + +static PyObject* WhichOneof(CMessage* self, PyObject* arg) { + Py_ssize_t name_size; + char *name_data; + if (PyString_AsStringAndSize(arg, &name_data, &name_size) < 0) + return NULL; + const OneofDescriptor* oneof_desc = + self->message->GetDescriptor()->FindOneofByName( + StringParam(name_data, name_size)); + if (oneof_desc == NULL) { + PyErr_Format(PyExc_ValueError, + "Protocol message has no oneof \"%s\" field.", name_data); + return NULL; + } + const FieldDescriptor* field_in_oneof = + self->message->GetReflection()->GetOneofFieldDescriptor( + *self->message, oneof_desc); + if (field_in_oneof == NULL) { + Py_RETURN_NONE; + } else { + const TProtoStringType& name = field_in_oneof->name(); + return PyString_FromStringAndSize(name.c_str(), name.size()); + } +} + +static PyObject* GetExtensionDict(CMessage* self, void *closure); + +static PyObject* ListFields(CMessage* self) { + std::vector<const FieldDescriptor*> fields; + self->message->GetReflection()->ListFields(*self->message, &fields); + + // Normally, the list will be exactly the size of the fields. + ScopedPyObjectPtr all_fields(PyList_New(fields.size())); + if (all_fields == NULL) { + return NULL; + } + + // When there are unknown extensions, the py list will *not* contain + // the field information. Thus the actual size of the py list will be + // smaller than the size of fields. Set the actual size at the end. + Py_ssize_t actual_size = 0; + for (size_t i = 0; i < fields.size(); ++i) { + ScopedPyObjectPtr t(PyTuple_New(2)); + if (t == NULL) { + return NULL; + } + + if (fields[i]->is_extension()) { + ScopedPyObjectPtr extension_field( + PyFieldDescriptor_FromDescriptor(fields[i])); + if (extension_field == NULL) { + return NULL; + } + // With C++ descriptors, the field can always be retrieved, but for + // unknown extensions which have not been imported in Python code, there + // is no message class and we cannot retrieve the value. + // TODO(amauryfa): consider building the class on the fly! + if (fields[i]->message_type() != NULL && + message_factory::GetMessageClass( + GetFactoryForMessage(self), + fields[i]->message_type()) == NULL) { + PyErr_Clear(); + continue; + } + ScopedPyObjectPtr extensions(GetExtensionDict(self, NULL)); + if (extensions == NULL) { + return NULL; + } + // 'extension' reference later stolen by PyTuple_SET_ITEM. + PyObject* extension = PyObject_GetItem( + extensions.get(), extension_field.get()); + if (extension == NULL) { + return NULL; + } + PyTuple_SET_ITEM(t.get(), 0, extension_field.release()); + // Steals reference to 'extension' + PyTuple_SET_ITEM(t.get(), 1, extension); + } else { + // Normal field + ScopedPyObjectPtr field_descriptor( + PyFieldDescriptor_FromDescriptor(fields[i])); + if (field_descriptor == NULL) { + return NULL; + } + + PyObject* field_value = GetFieldValue(self, fields[i]); + if (field_value == NULL) { + PyErr_SetString(PyExc_ValueError, fields[i]->name().c_str()); + return NULL; + } + PyTuple_SET_ITEM(t.get(), 0, field_descriptor.release()); + PyTuple_SET_ITEM(t.get(), 1, field_value); + } + PyList_SET_ITEM(all_fields.get(), actual_size, t.release()); + ++actual_size; + } + if (static_cast<size_t>(actual_size) != fields.size() && + (PyList_SetSlice(all_fields.get(), actual_size, fields.size(), NULL) < + 0)) { + return NULL; + } + return all_fields.release(); +} + +static PyObject* DiscardUnknownFields(CMessage* self) { + AssureWritable(self); + self->message->DiscardUnknownFields(); + Py_RETURN_NONE; +} + +PyObject* FindInitializationErrors(CMessage* self) { + Message* message = self->message; + std::vector<TProtoStringType> errors; + message->FindInitializationErrors(&errors); + + PyObject* error_list = PyList_New(errors.size()); + if (error_list == NULL) { + return NULL; + } + for (size_t i = 0; i < errors.size(); ++i) { + const TProtoStringType& error = errors[i]; + PyObject* error_string = PyString_FromStringAndSize( + error.c_str(), error.length()); + if (error_string == NULL) { + Py_DECREF(error_list); + return NULL; + } + PyList_SET_ITEM(error_list, i, error_string); + } + return error_list; +} + +static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) { + // Only equality comparisons are implemented. + if (opid != Py_EQ && opid != Py_NE) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + bool equals = true; + // If other is not a message, it cannot be equal. + if (!PyObject_TypeCheck(other, CMessage_Type)) { + equals = false; + } else { + // Otherwise, we have a CMessage whose message we can inspect. + const google::protobuf::Message* other_message = + reinterpret_cast<CMessage*>(other)->message; + // If messages don't have the same descriptors, they are not equal. + if (equals && + self->message->GetDescriptor() != other_message->GetDescriptor()) { + equals = false; + } + // Check the message contents. + if (equals && + !google::protobuf::util::MessageDifferencer::Equals( + *self->message, *reinterpret_cast<CMessage*>(other)->message)) { + equals = false; + } + } + + if (equals ^ (opid == Py_EQ)) { + Py_RETURN_FALSE; + } else { + Py_RETURN_TRUE; + } +} + +PyObject* InternalGetScalar(const Message* message, + const FieldDescriptor* field_descriptor) { + const Reflection* reflection = message->GetReflection(); + + if (!CheckFieldBelongsToMessage(field_descriptor, message)) { + return NULL; + } + + PyObject* result = NULL; + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: { + int32_t value = reflection->GetInt32(*message, field_descriptor); + result = PyInt_FromLong(value); + break; + } + case FieldDescriptor::CPPTYPE_INT64: { + int64_t value = reflection->GetInt64(*message, field_descriptor); + result = PyLong_FromLongLong(value); + break; + } + case FieldDescriptor::CPPTYPE_UINT32: { + uint32_t value = reflection->GetUInt32(*message, field_descriptor); + result = PyInt_FromSize_t(value); + break; + } + case FieldDescriptor::CPPTYPE_UINT64: { + uint64_t value = reflection->GetUInt64(*message, field_descriptor); + result = PyLong_FromUnsignedLongLong(value); + break; + } + case FieldDescriptor::CPPTYPE_FLOAT: { + float value = reflection->GetFloat(*message, field_descriptor); + result = PyFloat_FromDouble(value); + break; + } + case FieldDescriptor::CPPTYPE_DOUBLE: { + double value = reflection->GetDouble(*message, field_descriptor); + result = PyFloat_FromDouble(value); + break; + } + case FieldDescriptor::CPPTYPE_BOOL: { + bool value = reflection->GetBool(*message, field_descriptor); + result = PyBool_FromLong(value); + break; + } + case FieldDescriptor::CPPTYPE_STRING: { + TProtoStringType scratch; + const TProtoStringType& value = + reflection->GetStringReference(*message, field_descriptor, &scratch); + result = ToStringObject(field_descriptor, value); + break; + } + case FieldDescriptor::CPPTYPE_ENUM: { + const EnumValueDescriptor* enum_value = + message->GetReflection()->GetEnum(*message, field_descriptor); + result = PyInt_FromLong(enum_value->number()); + break; + } + default: + PyErr_Format( + PyExc_SystemError, "Getting a value from a field of unknown type %d", + field_descriptor->cpp_type()); + } + + return result; +} + +CMessage* InternalGetSubMessage( + CMessage* self, const FieldDescriptor* field_descriptor) { + const Reflection* reflection = self->message->GetReflection(); + PyMessageFactory* factory = GetFactoryForMessage(self); + const Message& sub_message = reflection->GetMessage( + *self->message, field_descriptor, factory->message_factory); + + CMessageClass* message_class = message_factory::GetOrCreateMessageClass( + factory, field_descriptor->message_type()); + ScopedPyObjectPtr message_class_owner( + reinterpret_cast<PyObject*>(message_class)); + if (message_class == NULL) { + return NULL; + } + + CMessage* cmsg = cmessage::NewEmptyMessage(message_class); + if (cmsg == NULL) { + return NULL; + } + + Py_INCREF(self); + cmsg->parent = self; + cmsg->parent_field_descriptor = field_descriptor; + cmsg->read_only = !reflection->HasField(*self->message, field_descriptor); + cmsg->message = const_cast<Message*>(&sub_message); + return cmsg; +} + +int InternalSetNonOneofScalar( + Message* message, + const FieldDescriptor* field_descriptor, + PyObject* arg) { + const Reflection* reflection = message->GetReflection(); + + if (!CheckFieldBelongsToMessage(field_descriptor, message)) { + return -1; + } + + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: { + GOOGLE_CHECK_GET_INT32(arg, value, -1); + reflection->SetInt32(message, field_descriptor, value); + break; + } + case FieldDescriptor::CPPTYPE_INT64: { + GOOGLE_CHECK_GET_INT64(arg, value, -1); + reflection->SetInt64(message, field_descriptor, value); + break; + } + case FieldDescriptor::CPPTYPE_UINT32: { + GOOGLE_CHECK_GET_UINT32(arg, value, -1); + reflection->SetUInt32(message, field_descriptor, value); + break; + } + case FieldDescriptor::CPPTYPE_UINT64: { + GOOGLE_CHECK_GET_UINT64(arg, value, -1); + reflection->SetUInt64(message, field_descriptor, value); + break; + } + case FieldDescriptor::CPPTYPE_FLOAT: { + GOOGLE_CHECK_GET_FLOAT(arg, value, -1); + reflection->SetFloat(message, field_descriptor, value); + break; + } + case FieldDescriptor::CPPTYPE_DOUBLE: { + GOOGLE_CHECK_GET_DOUBLE(arg, value, -1); + reflection->SetDouble(message, field_descriptor, value); + break; + } + case FieldDescriptor::CPPTYPE_BOOL: { + GOOGLE_CHECK_GET_BOOL(arg, value, -1); + reflection->SetBool(message, field_descriptor, value); + break; + } + case FieldDescriptor::CPPTYPE_STRING: { + if (!CheckAndSetString( + arg, message, field_descriptor, reflection, false, -1)) { + return -1; + } + break; + } + case FieldDescriptor::CPPTYPE_ENUM: { + GOOGLE_CHECK_GET_INT32(arg, value, -1); + if (reflection->SupportsUnknownEnumValues()) { + reflection->SetEnumValue(message, field_descriptor, value); + } else { + const EnumDescriptor* enum_descriptor = field_descriptor->enum_type(); + const EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != NULL) { + reflection->SetEnum(message, field_descriptor, enum_value); + } else { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value); + return -1; + } + } + break; + } + default: + PyErr_Format( + PyExc_SystemError, "Setting value to a field of unknown type %d", + field_descriptor->cpp_type()); + return -1; + } + + return 0; +} + +int InternalSetScalar( + CMessage* self, + const FieldDescriptor* field_descriptor, + PyObject* arg) { + if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) { + return -1; + } + + if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) { + return -1; + } + + return InternalSetNonOneofScalar(self->message, field_descriptor, arg); +} + +PyObject* FromString(PyTypeObject* cls, PyObject* serialized) { + PyObject* py_cmsg = PyObject_CallObject( + reinterpret_cast<PyObject*>(cls), NULL); + if (py_cmsg == NULL) { + return NULL; + } + CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg); + + ScopedPyObjectPtr py_length(MergeFromString(cmsg, serialized)); + if (py_length == NULL) { + Py_DECREF(py_cmsg); + return NULL; + } + + return py_cmsg; +} + +PyObject* DeepCopy(CMessage* self, PyObject* arg) { + PyObject* clone = PyObject_CallObject( + reinterpret_cast<PyObject*>(Py_TYPE(self)), NULL); + if (clone == NULL) { + return NULL; + } + if (!PyObject_TypeCheck(clone, CMessage_Type)) { + Py_DECREF(clone); + return NULL; + } + if (ScopedPyObjectPtr(MergeFrom( + reinterpret_cast<CMessage*>(clone), + reinterpret_cast<PyObject*>(self))) == NULL) { + Py_DECREF(clone); + return NULL; + } + return clone; +} + +PyObject* ToUnicode(CMessage* self) { + // Lazy import to prevent circular dependencies + ScopedPyObjectPtr text_format( + PyImport_ImportModule("google.protobuf.text_format")); + if (text_format == NULL) { + return NULL; + } + ScopedPyObjectPtr method_name(PyString_FromString("MessageToString")); + if (method_name == NULL) { + return NULL; + } + Py_INCREF(Py_True); + ScopedPyObjectPtr encoded(PyObject_CallMethodObjArgs( + text_format.get(), method_name.get(), self, Py_True, NULL)); + Py_DECREF(Py_True); + if (encoded == NULL) { + return NULL; + } +#if PY_MAJOR_VERSION < 3 + PyObject* decoded = PyString_AsDecodedObject(encoded.get(), "utf-8", NULL); +#else + PyObject* decoded = PyUnicode_FromEncodedObject(encoded.get(), "utf-8", NULL); +#endif + if (decoded == NULL) { + return NULL; + } + return decoded; +} + +// CMessage static methods: +PyObject* _CheckCalledFromGeneratedFile(PyObject* unused, + PyObject* unused_arg) { + if (!_CalledFromGeneratedFile(1)) { + PyErr_SetString(PyExc_TypeError, + "Descriptors should not be created directly, " + "but only retrieved from their parent."); + return NULL; + } + Py_RETURN_NONE; +} + +static PyObject* GetExtensionDict(CMessage* self, void *closure) { + // If there are extension_ranges, the message is "extendable". Allocate a + // dictionary to store the extension fields. + const Descriptor* descriptor = GetMessageDescriptor(Py_TYPE(self)); + if (!descriptor->extension_range_count()) { + PyErr_SetNone(PyExc_AttributeError); + return NULL; + } + if (!self->composite_fields) { + self->composite_fields = new CMessage::CompositeFieldsMap(); + } + if (!self->composite_fields) { + return NULL; + } + ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self); + return reinterpret_cast<PyObject*>(extension_dict); +} + +static PyObject* UnknownFieldSet(CMessage* self) { + if (self->unknown_field_set == NULL) { + self->unknown_field_set = unknown_fields::NewPyUnknownFields(self); + } else { + Py_INCREF(self->unknown_field_set); + } + return self->unknown_field_set; +} + +static PyObject* GetExtensionsByName(CMessage *self, void *closure) { + return message_meta::GetExtensionsByName( + reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure); +} + +static PyObject* GetExtensionsByNumber(CMessage *self, void *closure) { + return message_meta::GetExtensionsByNumber( + reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure); +} + +static PyGetSetDef Getters[] = { + {"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"}, + {"_extensions_by_name", (getter)GetExtensionsByName, NULL}, + {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL}, + {NULL} +}; + + +static PyMethodDef Methods[] = { + { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, + "Makes a deep copy of the class." }, + { "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS, + "Outputs a unicode representation of the message." }, + { "ByteSize", (PyCFunction)ByteSize, METH_NOARGS, + "Returns the size of the message in bytes." }, + { "Clear", (PyCFunction)Clear, METH_NOARGS, + "Clears the message." }, + { "ClearExtension", (PyCFunction)ClearExtension, METH_O, + "Clears a message field." }, + { "ClearField", (PyCFunction)ClearField, METH_O, + "Clears a message field." }, + { "CopyFrom", (PyCFunction)CopyFrom, METH_O, + "Copies a protocol message into the current message." }, + { "DiscardUnknownFields", (PyCFunction)DiscardUnknownFields, METH_NOARGS, + "Discards the unknown fields." }, + { "FindInitializationErrors", (PyCFunction)FindInitializationErrors, + METH_NOARGS, + "Finds unset required fields." }, + { "FromString", (PyCFunction)FromString, METH_O | METH_CLASS, + "Creates new method instance from given serialized data." }, + { "HasExtension", (PyCFunction)HasExtension, METH_O, + "Checks if a message field is set." }, + { "HasField", (PyCFunction)HasField, METH_O, + "Checks if a message field is set." }, + { "IsInitialized", (PyCFunction)IsInitialized, METH_VARARGS, + "Checks if all required fields of a protocol message are set." }, + { "ListFields", (PyCFunction)ListFields, METH_NOARGS, + "Lists all set fields of a message." }, + { "MergeFrom", (PyCFunction)MergeFrom, METH_O, + "Merges a protocol message into the current message." }, + { "MergeFromString", (PyCFunction)MergeFromString, METH_O, + "Merges a serialized message into the current message." }, + { "ParseFromString", (PyCFunction)ParseFromString, METH_O, + "Parses a serialized message into the current message." }, + { "RegisterExtension", (PyCFunction)RegisterExtension, METH_O | METH_CLASS, + "Registers an extension with the current message." }, + { "SerializePartialToString", (PyCFunction)SerializePartialToString, + METH_VARARGS | METH_KEYWORDS, + "Serializes the message to a string, even if it isn't initialized." }, + { "SerializeToString", (PyCFunction)SerializeToString, + METH_VARARGS | METH_KEYWORDS, + "Serializes the message to a string, only for initialized messages." }, + { "SetInParent", (PyCFunction)SetInParent, METH_NOARGS, + "Sets the has bit of the given field in its parent message." }, + { "UnknownFields", (PyCFunction)UnknownFieldSet, METH_NOARGS, + "Parse unknown field set"}, + { "WhichOneof", (PyCFunction)WhichOneof, METH_O, + "Returns the name of the field set inside a oneof, " + "or None if no field is set." }, + + // Static Methods. + { "_CheckCalledFromGeneratedFile", (PyCFunction)_CheckCalledFromGeneratedFile, + METH_NOARGS | METH_STATIC, + "Raises TypeError if the caller is not in a _pb2.py file."}, + { NULL, NULL} +}; + +bool SetCompositeField(CMessage* self, const FieldDescriptor* field, + ContainerBase* value) { + if (self->composite_fields == NULL) { + self->composite_fields = new CMessage::CompositeFieldsMap(); + } + (*self->composite_fields)[field] = value; + return true; +} + +bool SetSubmessage(CMessage* self, CMessage* submessage) { + if (self->child_submessages == NULL) { + self->child_submessages = new CMessage::SubMessagesMap(); + } + (*self->child_submessages)[submessage->message] = submessage; + return true; +} + +PyObject* GetAttr(PyObject* pself, PyObject* name) { + CMessage* self = reinterpret_cast<CMessage*>(pself); + PyObject* result = PyObject_GenericGetAttr( + reinterpret_cast<PyObject*>(self), name); + if (result != NULL) { + return result; + } + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + return NULL; + } + + PyErr_Clear(); + return message_meta::GetClassAttribute( + CheckMessageClass(Py_TYPE(self)), name); +} + +PyObject* GetFieldValue(CMessage* self, + const FieldDescriptor* field_descriptor) { + if (self->composite_fields) { + CMessage::CompositeFieldsMap::iterator it = + self->composite_fields->find(field_descriptor); + if (it != self->composite_fields->end()) { + ContainerBase* value = it->second; + Py_INCREF(value); + return value->AsPyObject(); + } + } + + if (self->message->GetDescriptor() != field_descriptor->containing_type()) { + PyErr_Format(PyExc_TypeError, + "descriptor to field '%s' doesn't apply to '%s' object", + field_descriptor->full_name().c_str(), + Py_TYPE(self)->tp_name); + return NULL; + } + + if (!field_descriptor->is_repeated() && + field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { + return InternalGetScalar(self->message, field_descriptor); + } + + ContainerBase* py_container = nullptr; + if (field_descriptor->is_map()) { + const Descriptor* entry_type = field_descriptor->message_type(); + const FieldDescriptor* value_type = entry_type->FindFieldByName("value"); + if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + CMessageClass* value_class = message_factory::GetMessageClass( + GetFactoryForMessage(self), value_type->message_type()); + if (value_class == NULL) { + return NULL; + } + py_container = + NewMessageMapContainer(self, field_descriptor, value_class); + } else { + py_container = NewScalarMapContainer(self, field_descriptor); + } + } else if (field_descriptor->is_repeated()) { + if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + CMessageClass* message_class = message_factory::GetMessageClass( + GetFactoryForMessage(self), field_descriptor->message_type()); + if (message_class == NULL) { + return NULL; + } + py_container = repeated_composite_container::NewContainer( + self, field_descriptor, message_class); + } else { + py_container = + repeated_scalar_container::NewContainer(self, field_descriptor); + } + } else if (field_descriptor->cpp_type() == + FieldDescriptor::CPPTYPE_MESSAGE) { + py_container = InternalGetSubMessage(self, field_descriptor); + } else { + PyErr_SetString(PyExc_SystemError, "Should never happen"); + } + + if (py_container == NULL) { + return NULL; + } + if (!SetCompositeField(self, field_descriptor, py_container)) { + Py_DECREF(py_container); + return NULL; + } + return py_container->AsPyObject(); +} + +int SetFieldValue(CMessage* self, const FieldDescriptor* field_descriptor, + PyObject* value) { + if (self->message->GetDescriptor() != field_descriptor->containing_type()) { + PyErr_Format(PyExc_TypeError, + "descriptor to field '%s' doesn't apply to '%s' object", + field_descriptor->full_name().c_str(), + Py_TYPE(self)->tp_name); + return -1; + } else if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + PyErr_Format(PyExc_AttributeError, + "Assignment not allowed to repeated " + "field \"%s\" in protocol message object.", + field_descriptor->name().c_str()); + return -1; + } else if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + PyErr_Format(PyExc_AttributeError, + "Assignment not allowed to " + "field \"%s\" in protocol message object.", + field_descriptor->name().c_str()); + return -1; + } else { + AssureWritable(self); + return InternalSetScalar(self, field_descriptor, value); + } +} + +} // namespace cmessage + +// All containers which are not messages: +// - Make a new parent message +// - Copy the field +// - return the field. +PyObject* ContainerBase::DeepCopy() { + CMessage* new_parent = + cmessage::NewEmptyMessage(this->parent->GetMessageClass()); + new_parent->message = this->parent->message->New(nullptr); + + // Copy the map field into the new message. + this->parent->message->GetReflection()->SwapFields( + this->parent->message, new_parent->message, + {this->parent_field_descriptor}); + this->parent->message->MergeFrom(*new_parent->message); + + PyObject* result = + cmessage::GetFieldValue(new_parent, this->parent_field_descriptor); + Py_DECREF(new_parent); + return result; +} + +void ContainerBase::RemoveFromParentCache() { + CMessage* parent = this->parent; + if (parent) { + if (parent->composite_fields) + parent->composite_fields->erase(this->parent_field_descriptor); + Py_CLEAR(parent); + } +} + +CMessage* CMessage::BuildSubMessageFromPointer( + const FieldDescriptor* field_descriptor, Message* sub_message, + CMessageClass* message_class) { + if (!this->child_submessages) { + this->child_submessages = new CMessage::SubMessagesMap(); + } + CMessage* cmsg = FindPtrOrNull( + *this->child_submessages, sub_message); + if (cmsg) { + Py_INCREF(cmsg); + } else { + cmsg = cmessage::NewEmptyMessage(message_class); + + if (cmsg == NULL) { + return NULL; + } + cmsg->message = sub_message; + Py_INCREF(this); + cmsg->parent = this; + cmsg->parent_field_descriptor = field_descriptor; + cmessage::SetSubmessage(this, cmsg); + } + return cmsg; +} + +CMessage* CMessage::MaybeReleaseSubMessage(Message* sub_message) { + if (!this->child_submessages) { + return nullptr; + } + CMessage* released = FindPtrOrNull( + *this->child_submessages, sub_message); + if (!released) { + return nullptr; + } + // The target message will now own its content. + Py_CLEAR(released->parent); + released->parent_field_descriptor = nullptr; + released->read_only = false; + // Delete it from the cache. + this->child_submessages->erase(sub_message); + return released; +} + +static CMessageClass _CMessage_Type = { { { + PyVarObject_HEAD_INIT(&_CMessageClass_Type, 0) + FULL_MODULE_NAME ".CMessage", // tp_name + sizeof(CMessage), // tp_basicsize + 0, // tp_itemsize + (destructor)cmessage::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + (reprfunc)cmessage::ToStr, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + PyObject_HashNotImplemented, // tp_hash + 0, // tp_call + (reprfunc)cmessage::ToStr, // tp_str + cmessage::GetAttr, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE + | Py_TPFLAGS_HAVE_VERSION_TAG, // tp_flags + "A ProtocolMessage", // tp_doc + 0, // tp_traverse + 0, // tp_clear + (richcmpfunc)cmessage::RichCompare, // tp_richcompare + offsetof(CMessage, weakreflist), // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + cmessage::Methods, // tp_methods + 0, // tp_members + cmessage::Getters, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + (initproc)cmessage::Init, // tp_init + 0, // tp_alloc + cmessage::New, // tp_new +} } }; +PyTypeObject* CMessage_Type = &_CMessage_Type.super.ht_type; + +// --- Exposing the C proto living inside Python proto to C code: + +const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg); +Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg); + +static const Message* GetCProtoInsidePyProtoImpl(PyObject* msg) { + const Message* message = PyMessage_GetMessagePointer(msg); + if (message == NULL) { + PyErr_Clear(); + return NULL; + } + return message; +} + +static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) { + Message* message = PyMessage_GetMutableMessagePointer(msg); + if (message == NULL) { + PyErr_Clear(); + return NULL; + } + return message; +} + +const Message* PyMessage_GetMessagePointer(PyObject* msg) { + if (!PyObject_TypeCheck(msg, CMessage_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a Message instance"); + return NULL; + } + CMessage* cmsg = reinterpret_cast<CMessage*>(msg); + return cmsg->message; +} + +Message* PyMessage_GetMutableMessagePointer(PyObject* msg) { + if (!PyObject_TypeCheck(msg, CMessage_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a Message instance"); + return NULL; + } + CMessage* cmsg = reinterpret_cast<CMessage*>(msg); + + + if ((cmsg->composite_fields && !cmsg->composite_fields->empty()) || + (cmsg->child_submessages && !cmsg->child_submessages->empty())) { + // There is currently no way of accurately syncing arbitrary changes to + // the underlying C++ message back to the CMessage (e.g. removed repeated + // composite containers). We only allow direct mutation of the underlying + // C++ message if there is no child data in the CMessage. + PyErr_SetString(PyExc_ValueError, + "Cannot reliably get a mutable pointer " + "to a message with extra references"); + return NULL; + } + cmessage::AssureWritable(cmsg); + return cmsg->message; +} + +PyObject* PyMessage_New(const Descriptor* descriptor, + PyObject* py_message_factory) { + PyMessageFactory* factory = nullptr; + if (py_message_factory == nullptr) { + factory = GetDescriptorPool_FromPool(descriptor->file()->pool()) + ->py_message_factory; + } else if (PyObject_TypeCheck(py_message_factory, &PyMessageFactory_Type)) { + factory = reinterpret_cast<PyMessageFactory*>(py_message_factory); + } else { + PyErr_SetString(PyExc_TypeError, "Expected a MessageFactory"); + return nullptr; + } + auto* message_class = + message_factory::GetOrCreateMessageClass(factory, descriptor); + if (message_class == nullptr) { + return nullptr; + } + + CMessage* self = cmessage::NewCMessage(message_class); + Py_DECREF(message_class); + if (self == nullptr) { + return nullptr; + } + return self->AsPyObject(); +} + +PyObject* PyMessage_NewMessageOwnedExternally(Message* message, + PyObject* py_message_factory) { + if (py_message_factory) { + PyErr_SetString(PyExc_NotImplementedError, + "Default message_factory=NULL is the only supported value"); + return nullptr; + } + if (message->GetReflection()->GetMessageFactory() != + MessageFactory::generated_factory()) { + PyErr_SetString(PyExc_TypeError, + "Message pointer was not created from the default factory"); + return nullptr; + } + + CMessageClass* message_class = message_factory::GetOrCreateMessageClass( + GetDefaultDescriptorPool()->py_message_factory, message->GetDescriptor()); + if (message_class == nullptr) { + return nullptr; + } + + CMessage* self = cmessage::NewEmptyMessage(message_class); + Py_DECREF(message_class); + if (self == nullptr) { + return nullptr; + } + self->message = message; + Py_INCREF(Py_None); + self->parent = reinterpret_cast<CMessage*>(Py_None); + return self->AsPyObject(); +} + +void InitGlobals() { + // TODO(gps): Check all return values in this function for NULL and propagate + // the error (MemoryError) on up to result in an import failure. These should + // also be freed and reset to NULL during finalization. + kDESCRIPTOR = PyString_FromString("DESCRIPTOR"); + + PyObject *dummy_obj = PySet_New(NULL); + kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL); + Py_DECREF(dummy_obj); +} + +bool InitProto2MessageModule(PyObject *m) { + // Initialize types and globals in descriptor.cc + if (!InitDescriptor()) { + return false; + } + + // Initialize types and globals in descriptor_pool.cc + if (!InitDescriptorPool()) { + return false; + } + + // Initialize types and globals in message_factory.cc + if (!InitMessageFactory()) { + return false; + } + + // Initialize constants defined in this file. + InitGlobals(); + + CMessageClass_Type->tp_base = &PyType_Type; + if (PyType_Ready(CMessageClass_Type) < 0) { + return false; + } + PyModule_AddObject(m, "MessageMeta", + reinterpret_cast<PyObject*>(CMessageClass_Type)); + + if (PyType_Ready(CMessage_Type) < 0) { + return false; + } + if (PyType_Ready(CFieldProperty_Type) < 0) { + return false; + } + + // DESCRIPTOR is set on each protocol buffer message class elsewhere, but set + // it here as well to document that subclasses need to set it. + PyDict_SetItem(CMessage_Type->tp_dict, kDESCRIPTOR, Py_None); + // Invalidate any cached data for the CMessage type. + // This call is necessary to correctly support Py_TPFLAGS_HAVE_VERSION_TAG, + // after we have modified CMessage_Type.tp_dict. + PyType_Modified(CMessage_Type); + + PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(CMessage_Type)); + + // Initialize Repeated container types. + { + if (PyType_Ready(&RepeatedScalarContainer_Type) < 0) { + return false; + } + + PyModule_AddObject( + m, "RepeatedScalarContainer", + reinterpret_cast<PyObject*>(&RepeatedScalarContainer_Type)); + + if (PyType_Ready(&RepeatedCompositeContainer_Type) < 0) { + return false; + } + + PyModule_AddObject( + m, "RepeatedCompositeContainer", + reinterpret_cast<PyObject*>(&RepeatedCompositeContainer_Type)); + + // Register them as MutableSequence. +#if PY_MAJOR_VERSION >= 3 + ScopedPyObjectPtr collections(PyImport_ImportModule("collections.abc")); +#else + ScopedPyObjectPtr collections(PyImport_ImportModule("collections")); +#endif + if (collections == NULL) { + return false; + } + ScopedPyObjectPtr mutable_sequence( + PyObject_GetAttrString(collections.get(), "MutableSequence")); + if (mutable_sequence == NULL) { + return false; + } + if (ScopedPyObjectPtr( + PyObject_CallMethod(mutable_sequence.get(), "register", "O", + &RepeatedScalarContainer_Type)) == NULL) { + return false; + } + if (ScopedPyObjectPtr( + PyObject_CallMethod(mutable_sequence.get(), "register", "O", + &RepeatedCompositeContainer_Type)) == NULL) { + return false; + } + } + + if (PyType_Ready(&PyUnknownFields_Type) < 0) { + return false; + } + + PyModule_AddObject(m, "UnknownFieldSet", + reinterpret_cast<PyObject*>(&PyUnknownFields_Type)); + + if (PyType_Ready(&PyUnknownFieldRef_Type) < 0) { + return false; + } + + PyModule_AddObject(m, "UnknownField", + reinterpret_cast<PyObject*>(&PyUnknownFieldRef_Type)); + + // Initialize Map container types. + if (!InitMapContainers()) { + return false; + } + PyModule_AddObject(m, "ScalarMapContainer", + reinterpret_cast<PyObject*>(ScalarMapContainer_Type)); + PyModule_AddObject(m, "MessageMapContainer", + reinterpret_cast<PyObject*>(MessageMapContainer_Type)); + PyModule_AddObject(m, "MapIterator", + reinterpret_cast<PyObject*>(&MapIterator_Type)); + + if (PyType_Ready(&ExtensionDict_Type) < 0) { + return false; + } + PyModule_AddObject(m, "ExtensionDict", + reinterpret_cast<PyObject*>(&ExtensionDict_Type)); + if (PyType_Ready(&ExtensionIterator_Type) < 0) { + return false; + } + PyModule_AddObject(m, "ExtensionIterator", + reinterpret_cast<PyObject*>(&ExtensionIterator_Type)); + + // Expose the DescriptorPool used to hold all descriptors added from generated + // pb2.py files. + // PyModule_AddObject steals a reference. + Py_INCREF(GetDefaultDescriptorPool()); + PyModule_AddObject(m, "default_pool", + reinterpret_cast<PyObject*>(GetDefaultDescriptorPool())); + + PyModule_AddObject(m, "DescriptorPool", + reinterpret_cast<PyObject*>(&PyDescriptorPool_Type)); + PyModule_AddObject(m, "Descriptor", + reinterpret_cast<PyObject*>(&PyMessageDescriptor_Type)); + PyModule_AddObject(m, "FieldDescriptor", + reinterpret_cast<PyObject*>(&PyFieldDescriptor_Type)); + PyModule_AddObject(m, "EnumDescriptor", + reinterpret_cast<PyObject*>(&PyEnumDescriptor_Type)); + PyModule_AddObject(m, "EnumValueDescriptor", + reinterpret_cast<PyObject*>(&PyEnumValueDescriptor_Type)); + PyModule_AddObject(m, "FileDescriptor", + reinterpret_cast<PyObject*>(&PyFileDescriptor_Type)); + PyModule_AddObject(m, "OneofDescriptor", + reinterpret_cast<PyObject*>(&PyOneofDescriptor_Type)); + PyModule_AddObject(m, "ServiceDescriptor", + reinterpret_cast<PyObject*>(&PyServiceDescriptor_Type)); + PyModule_AddObject(m, "MethodDescriptor", + reinterpret_cast<PyObject*>(&PyMethodDescriptor_Type)); + + PyObject* enum_type_wrapper = PyImport_ImportModule( + "google.protobuf.internal.enum_type_wrapper"); + if (enum_type_wrapper == NULL) { + return false; + } + EnumTypeWrapper_class = + PyObject_GetAttrString(enum_type_wrapper, "EnumTypeWrapper"); + Py_DECREF(enum_type_wrapper); + + PyObject* message_module = PyImport_ImportModule( + "google.protobuf.message"); + if (message_module == NULL) { + return false; + } + EncodeError_class = PyObject_GetAttrString(message_module, "EncodeError"); + DecodeError_class = PyObject_GetAttrString(message_module, "DecodeError"); + PythonMessage_class = PyObject_GetAttrString(message_module, "Message"); + Py_DECREF(message_module); + + PyObject* pickle_module = PyImport_ImportModule("pickle"); + if (pickle_module == NULL) { + return false; + } + PickleError_class = PyObject_GetAttrString(pickle_module, "PickleError"); + Py_DECREF(pickle_module); + + // Override {Get,Mutable}CProtoInsidePyProto. + GetCProtoInsidePyProtoPtr = GetCProtoInsidePyProtoImpl; + MutableCProtoInsidePyProtoPtr = MutableCProtoInsidePyProtoImpl; + + return true; +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/message.h b/contrib/python/protobuf/py3/google/protobuf/pyext/message.h new file mode 100644 index 0000000000..ca81a87521 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/message.h @@ -0,0 +1,376 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__ + +#include <Python.h> + +#include <cstdint> +#include <memory> +#include <string> +#include <unordered_map> + +#include <google/protobuf/stubs/common.h> + +namespace google { +namespace protobuf { + +class Message; +class Reflection; +class FieldDescriptor; +class Descriptor; +class DescriptorPool; +class MessageFactory; + +namespace python { + +struct ExtensionDict; +struct PyMessageFactory; +struct CMessageClass; + +// Most of the complexity of the Message class comes from the "Release" +// behavior: +// +// When a field is cleared, it is only detached from its message. Existing +// references to submessages, to repeated container etc. won't see any change, +// as if the data was effectively managed by these containers. +// +// ExtensionDicts and UnknownFields containers do NOT follow this rule. They +// don't store any data, and always refer to their parent message. + +struct ContainerBase { + PyObject_HEAD; + + // Strong reference to a parent message object. For a CMessage there are three + // cases: + // - For a top-level message, this pointer is NULL. + // - For a sub-message, this points to the parent message. + // - For a message managed externally, this is a owned reference to Py_None. + // + // For all other types: repeated containers, maps, it always point to a + // valid parent CMessage. + struct CMessage* parent; + + // If this object belongs to a parent message, describes which field it comes + // from. + // The pointer is owned by the DescriptorPool (which is kept alive + // through the message's Python class) + const FieldDescriptor* parent_field_descriptor; + + PyObject* AsPyObject() { return reinterpret_cast<PyObject*>(this); } + + // The Three methods below are only used by Repeated containers, and Maps. + + // This implementation works for all containers which have a parent. + PyObject* DeepCopy(); + // Delete this container object from its parent. Does not work for messages. + void RemoveFromParentCache(); +}; + +typedef struct CMessage : public ContainerBase { + // Pointer to the C++ Message object for this CMessage. + // - If this object has no parent, we own this pointer. + // - If this object has a parent message, the parent owns this pointer. + Message* message; + + // Indicates this submessage is pointing to a default instance of a message. + // Submessages are always first created as read only messages and are then + // made writable, at which point this field is set to false. + bool read_only; + + // A mapping indexed by field, containing weak references to contained objects + // which need to implement the "Release" mechanism: + // direct submessages, RepeatedCompositeContainer, RepeatedScalarContainer + // and MapContainer. + typedef std::unordered_map<const FieldDescriptor*, ContainerBase*> + CompositeFieldsMap; + CompositeFieldsMap* composite_fields; + + // A mapping containing weak references to indirect child messages, accessed + // through containers: repeated messages, and values of message maps. + // This avoid the creation of similar maps in each of those containers. + typedef std::unordered_map<const Message*, CMessage*> SubMessagesMap; + SubMessagesMap* child_submessages; + + // A reference to PyUnknownFields. + PyObject* unknown_field_set; + + // Implements the "weakref" protocol for this object. + PyObject* weakreflist; + + // Return a *borrowed* reference to the message class. + CMessageClass* GetMessageClass() { + return reinterpret_cast<CMessageClass*>(Py_TYPE(this)); + } + + // For container containing messages, return a Python object for the given + // pointer to a message. + CMessage* BuildSubMessageFromPointer(const FieldDescriptor* field_descriptor, + Message* sub_message, + CMessageClass* message_class); + CMessage* MaybeReleaseSubMessage(Message* sub_message); +} CMessage; + +// The (meta) type of all Messages classes. +// It allows us to cache some C++ pointers in the class object itself, they are +// faster to extract than from the type's dictionary. + +struct CMessageClass { + // This is how CPython subclasses C structures: the base structure must be + // the first member of the object. + PyHeapTypeObject super; + + // C++ descriptor of this message. + const Descriptor* message_descriptor; + + // Owned reference, used to keep the pointer above alive. + // This reference must stay alive until all message pointers are destructed. + PyObject* py_message_descriptor; + + // The Python MessageFactory used to create the class. It is needed to resolve + // fields descriptors, including extensions fields; its C++ MessageFactory is + // used to instantiate submessages. + // This reference must stay alive until all message pointers are destructed. + PyMessageFactory* py_message_factory; + + PyObject* AsPyObject() { + return reinterpret_cast<PyObject*>(this); + } +}; + +extern PyTypeObject* CMessageClass_Type; +extern PyTypeObject* CMessage_Type; + +namespace cmessage { + +// Internal function to create a new empty Message Python object, but with empty +// pointers to the C++ objects. +// The caller must fill self->message, self->owner and eventually self->parent. +CMessage* NewEmptyMessage(CMessageClass* type); + +// Retrieves the C++ descriptor of a Python Extension descriptor. +// On error, return NULL with an exception set. +const FieldDescriptor* GetExtensionDescriptor(PyObject* extension); + +// Initializes a new CMessage instance for a submessage. Only called once per +// submessage as the result is cached in composite_fields. +// +// Corresponds to reflection api method GetMessage. +CMessage* InternalGetSubMessage( + CMessage* self, const FieldDescriptor* field_descriptor); + +// Deletes a range of items in a repeated field (following a +// removal in a RepeatedCompositeContainer). +// +// Corresponds to reflection api method RemoveLast. +int DeleteRepeatedField(CMessage* self, + const FieldDescriptor* field_descriptor, + PyObject* slice); + +// Sets the specified scalar value to the message. +int InternalSetScalar(CMessage* self, + const FieldDescriptor* field_descriptor, + PyObject* value); + +// Sets the specified scalar value to the message. Requires it is not a Oneof. +int InternalSetNonOneofScalar(Message* message, + const FieldDescriptor* field_descriptor, + PyObject* arg); + +// Retrieves the specified scalar value from the message. +// +// Returns a new python reference. +PyObject* InternalGetScalar(const Message* message, + const FieldDescriptor* field_descriptor); + +bool SetCompositeField(CMessage* self, const FieldDescriptor* field, + ContainerBase* value); + +bool SetSubmessage(CMessage* self, CMessage* submessage); + +// Clears the message, removing all contained data. Extension dictionary and +// submessages are released first if there are remaining external references. +// +// Corresponds to message api method Clear. +PyObject* Clear(CMessage* self); + +// Clears the data described by the given descriptor. +// Returns -1 on error. +// +// Corresponds to reflection api method ClearField. +int ClearFieldByDescriptor(CMessage* self, const FieldDescriptor* descriptor); + +// Checks if the message has the field described by the descriptor. Used for +// extensions (which have no name). +// Returns 1 if true, 0 if false, and -1 on error. +// +// Corresponds to reflection api method HasField +int HasFieldByDescriptor(CMessage* self, + const FieldDescriptor* field_descriptor); + +// Checks if the message has the named field. +// +// Corresponds to reflection api method HasField. +PyObject* HasField(CMessage* self, PyObject* arg); + +// Initializes values of fields on a newly constructed message. +// Note that positional arguments are disallowed: 'args' must be NULL or the +// empty tuple. +int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs); + +PyObject* MergeFrom(CMessage* self, PyObject* arg); + +// This method does not do anything beyond checking that no other extension +// has been registered with the same field number on this class. +PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle); + +// Get a field from a message. +PyObject* GetFieldValue(CMessage* self, + const FieldDescriptor* field_descriptor); +// Sets the value of a scalar field in a message. +// On error, return -1 with an extension set. +int SetFieldValue(CMessage* self, const FieldDescriptor* field_descriptor, + PyObject* value); + +PyObject* FindInitializationErrors(CMessage* self); + +int AssureWritable(CMessage* self); + +// Returns the message factory for the given message. +// This is equivalent to message.MESSAGE_FACTORY +// +// The returned factory is suitable for finding fields and building submessages, +// even in the case of extensions. +// Returns a *borrowed* reference, and never fails because we pass a CMessage. +PyMessageFactory* GetFactoryForMessage(CMessage* message); + +PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg); + +} // namespace cmessage + + +/* Is 64bit */ +#define IS_64BIT (SIZEOF_LONG == 8) + +#define FIELD_IS_REPEATED(field_descriptor) \ + ((field_descriptor)->label() == FieldDescriptor::LABEL_REPEATED) + +#define GOOGLE_CHECK_GET_INT32(arg, value, err) \ + int32_t value; \ + if (!CheckAndGetInteger(arg, &value)) { \ + return err; \ + } + +#define GOOGLE_CHECK_GET_INT64(arg, value, err) \ + int64_t value; \ + if (!CheckAndGetInteger(arg, &value)) { \ + return err; \ + } + +#define GOOGLE_CHECK_GET_UINT32(arg, value, err) \ + uint32_t value; \ + if (!CheckAndGetInteger(arg, &value)) { \ + return err; \ + } + +#define GOOGLE_CHECK_GET_UINT64(arg, value, err) \ + uint64_t value; \ + if (!CheckAndGetInteger(arg, &value)) { \ + return err; \ + } + +#define GOOGLE_CHECK_GET_FLOAT(arg, value, err) \ + float value; \ + if (!CheckAndGetFloat(arg, &value)) { \ + return err; \ + } + +#define GOOGLE_CHECK_GET_DOUBLE(arg, value, err) \ + double value; \ + if (!CheckAndGetDouble(arg, &value)) { \ + return err; \ + } + +#define GOOGLE_CHECK_GET_BOOL(arg, value, err) \ + bool value; \ + if (!CheckAndGetBool(arg, &value)) { \ + return err; \ + } + +#define FULL_MODULE_NAME "google.protobuf.pyext._message" + +void FormatTypeError(PyObject* arg, const char* expected_types); +template<class T> +bool CheckAndGetInteger(PyObject* arg, T* value); +bool CheckAndGetDouble(PyObject* arg, double* value); +bool CheckAndGetFloat(PyObject* arg, float* value); +bool CheckAndGetBool(PyObject* arg, bool* value); +PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor); +bool CheckAndSetString( + PyObject* arg, Message* message, + const FieldDescriptor* descriptor, + const Reflection* reflection, + bool append, + int index); +PyObject* ToStringObject(const FieldDescriptor* descriptor, + const TProtoStringType& value); + +// Check if the passed field descriptor belongs to the given message. +// If not, return false and set a Python exception (a KeyError) +bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, + const Message* message); + +extern PyObject* PickleError_class; + +PyObject* PyMessage_New(const Descriptor* descriptor, + PyObject* py_message_factory); +const Message* PyMessage_GetMessagePointer(PyObject* msg); +Message* PyMessage_GetMutableMessagePointer(PyObject* msg); +PyObject* PyMessage_NewMessageOwnedExternally(Message* message, + PyObject* py_message_factory); + +bool InitProto2MessageModule(PyObject *m); + +// These are referenced by repeated_scalar_container, and must +// be explicitly instantiated. +extern template bool CheckAndGetInteger<int32>(PyObject*, int32*); +extern template bool CheckAndGetInteger<int64>(PyObject*, int64*); +extern template bool CheckAndGetInteger<uint32>(PyObject*, uint32*); +extern template bool CheckAndGetInteger<uint64>(PyObject*, uint64*); + +} // namespace python +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__ diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/message_factory.cc b/contrib/python/protobuf/py3/google/protobuf/pyext/message_factory.cc new file mode 100644 index 0000000000..7905be0214 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/message_factory.cc @@ -0,0 +1,304 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include <unordered_map> + +#include <Python.h> + +#include <google/protobuf/dynamic_message.h> +#include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/message_factory.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +#if PY_MAJOR_VERSION >= 3 + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #endif + #define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \ + PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \ + ? -1 \ + : 0) \ + : PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif + +namespace google { +namespace protobuf { +namespace python { + +namespace message_factory { + +PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) { + PyMessageFactory* factory = reinterpret_cast<PyMessageFactory*>( + PyType_GenericAlloc(type, 0)); + if (factory == NULL) { + return NULL; + } + + DynamicMessageFactory* message_factory = new DynamicMessageFactory(); + // This option might be the default some day. + message_factory->SetDelegateToGeneratedFactory(true); + factory->message_factory = message_factory; + + factory->pool = pool; + Py_INCREF(pool); + + factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap(); + + return factory; +} + +PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) { + static const char* kwlist[] = {"pool", 0}; + PyObject* pool = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", + const_cast<char**>(kwlist), &pool)) { + return NULL; + } + ScopedPyObjectPtr owned_pool; + if (pool == NULL || pool == Py_None) { + owned_pool.reset(PyObject_CallFunction( + reinterpret_cast<PyObject*>(&PyDescriptorPool_Type), NULL)); + if (owned_pool == NULL) { + return NULL; + } + pool = owned_pool.get(); + } else { + if (!PyObject_TypeCheck(pool, &PyDescriptorPool_Type)) { + PyErr_Format(PyExc_TypeError, "Expected a DescriptorPool, got %s", + pool->ob_type->tp_name); + return NULL; + } + } + + return reinterpret_cast<PyObject*>( + NewMessageFactory(type, reinterpret_cast<PyDescriptorPool*>(pool))); +} + +static void Dealloc(PyObject* pself) { + PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself); + + typedef PyMessageFactory::ClassesByMessageMap::iterator iterator; + for (iterator it = self->classes_by_descriptor->begin(); + it != self->classes_by_descriptor->end(); ++it) { + Py_CLEAR(it->second); + } + delete self->classes_by_descriptor; + delete self->message_factory; + Py_CLEAR(self->pool); + Py_TYPE(self)->tp_free(pself); +} + +static int GcTraverse(PyObject* pself, visitproc visit, void* arg) { + PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself); + Py_VISIT(self->pool); + for (const auto& desc_and_class : *self->classes_by_descriptor) { + Py_VISIT(desc_and_class.second); + } + return 0; +} + +static int GcClear(PyObject* pself) { + PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself); + // Here it's important to not clear self->pool, so that the C++ DescriptorPool + // is still alive when self->message_factory is destructed. + for (auto& desc_and_class : *self->classes_by_descriptor) { + Py_CLEAR(desc_and_class.second); + } + + return 0; +} + +// Add a message class to our database. +int RegisterMessageClass(PyMessageFactory* self, + const Descriptor* message_descriptor, + CMessageClass* message_class) { + Py_INCREF(message_class); + typedef PyMessageFactory::ClassesByMessageMap::iterator iterator; + std::pair<iterator, bool> ret = self->classes_by_descriptor->insert( + std::make_pair(message_descriptor, message_class)); + if (!ret.second) { + // Update case: DECREF the previous value. + Py_DECREF(ret.first->second); + ret.first->second = message_class; + } + return 0; +} + +CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self, + const Descriptor* descriptor) { + // This is the same implementation as MessageFactory.GetPrototype(). + + // Do not create a MessageClass that already exists. + std::unordered_map<const Descriptor*, CMessageClass*>::iterator it = + self->classes_by_descriptor->find(descriptor); + if (it != self->classes_by_descriptor->end()) { + Py_INCREF(it->second); + return it->second; + } + ScopedPyObjectPtr py_descriptor( + PyMessageDescriptor_FromDescriptor(descriptor)); + if (py_descriptor == NULL) { + return NULL; + } + // Create a new message class. + ScopedPyObjectPtr args(Py_BuildValue( + "s(){sOsOsO}", descriptor->name().c_str(), + "DESCRIPTOR", py_descriptor.get(), + "__module__", Py_None, + "message_factory", self)); + if (args == NULL) { + return NULL; + } + ScopedPyObjectPtr message_class(PyObject_CallObject( + reinterpret_cast<PyObject*>(CMessageClass_Type), args.get())); + if (message_class == NULL) { + return NULL; + } + // Create messages class for the messages used by the fields, and registers + // all extensions for these messages during the recursion. + for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) { + const Descriptor* sub_descriptor = + descriptor->field(field_idx)->message_type(); + // It is NULL if the field type is not a message. + if (sub_descriptor != NULL) { + CMessageClass* result = GetOrCreateMessageClass(self, sub_descriptor); + if (result == NULL) { + return NULL; + } + Py_DECREF(result); + } + } + + // Register extensions defined in this message. + for (int ext_idx = 0 ; ext_idx < descriptor->extension_count() ; ext_idx++) { + const FieldDescriptor* extension = descriptor->extension(ext_idx); + ScopedPyObjectPtr py_extended_class( + GetOrCreateMessageClass(self, extension->containing_type()) + ->AsPyObject()); + if (py_extended_class == NULL) { + return NULL; + } + ScopedPyObjectPtr py_extension(PyFieldDescriptor_FromDescriptor(extension)); + if (py_extension == NULL) { + return NULL; + } + ScopedPyObjectPtr result(cmessage::RegisterExtension( + py_extended_class.get(), py_extension.get())); + if (result == NULL) { + return NULL; + } + } + return reinterpret_cast<CMessageClass*>(message_class.release()); +} + +// Retrieve the message class added to our database. +CMessageClass* GetMessageClass(PyMessageFactory* self, + const Descriptor* message_descriptor) { + typedef PyMessageFactory::ClassesByMessageMap::iterator iterator; + iterator ret = self->classes_by_descriptor->find(message_descriptor); + if (ret == self->classes_by_descriptor->end()) { + PyErr_Format(PyExc_TypeError, "No message class registered for '%s'", + message_descriptor->full_name().c_str()); + return NULL; + } else { + return ret->second; + } +} + +static PyMethodDef Methods[] = { + {NULL}}; + +static PyObject* GetPool(PyMessageFactory* self, void* closure) { + Py_INCREF(self->pool); + return reinterpret_cast<PyObject*>(self->pool); +} + +static PyGetSetDef Getters[] = { + {"pool", (getter)GetPool, NULL, "DescriptorPool"}, + {NULL} +}; + +} // namespace message_factory + +PyTypeObject PyMessageFactory_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME + ".MessageFactory", // tp_name + sizeof(PyMessageFactory), // tp_basicsize + 0, // tp_itemsize + message_factory::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, // tp_flags + "A static Message Factory", // tp_doc + message_factory::GcTraverse, // tp_traverse + message_factory::GcClear, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + message_factory::Methods, // tp_methods + 0, // tp_members + message_factory::Getters, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + message_factory::New, // tp_new + PyObject_GC_Del, // tp_free +}; + +bool InitMessageFactory() { + if (PyType_Ready(&PyMessageFactory_Type) < 0) { + return false; + } + + return true; +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/message_factory.h b/contrib/python/protobuf/py3/google/protobuf/pyext/message_factory.h new file mode 100644 index 0000000000..515c29cdb8 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/message_factory.h @@ -0,0 +1,103 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__ + +#include <Python.h> + +#include <unordered_map> +#include <google/protobuf/descriptor.h> +#include <google/protobuf/pyext/descriptor_pool.h> + +namespace google { +namespace protobuf { +class MessageFactory; + +namespace python { + +// The (meta) type of all Messages classes. +struct CMessageClass; + +struct PyMessageFactory { + PyObject_HEAD + + // DynamicMessageFactory used to create C++ instances of messages. + // This object cache the descriptors that were used, so the DescriptorPool + // needs to get rid of it before it can delete itself. + // + // Note: A C++ MessageFactory is different from the PyMessageFactory. + // The C++ one creates messages, when the Python one creates classes. + MessageFactory* message_factory; + + // Owned reference to a Python DescriptorPool. + // This reference must stay until the message_factory is destructed. + PyDescriptorPool* pool; + + // Make our own mapping to retrieve Python classes from C++ descriptors. + // + // Descriptor pointers stored here are owned by the DescriptorPool above. + // Python references to classes are owned by this PyDescriptorPool. + typedef std::unordered_map<const Descriptor*, CMessageClass*> + ClassesByMessageMap; + ClassesByMessageMap* classes_by_descriptor; +}; + +extern PyTypeObject PyMessageFactory_Type; + +namespace message_factory { + +// Creates a new MessageFactory instance. +PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool); + +// Registers a new Python class for the given message descriptor. +// On error, returns -1 with a Python exception set. +int RegisterMessageClass(PyMessageFactory* self, + const Descriptor* message_descriptor, + CMessageClass* message_class); +// Retrieves the Python class registered with the given message descriptor, or +// fail with a TypeError. Returns a *borrowed* reference. +CMessageClass* GetMessageClass(PyMessageFactory* self, + const Descriptor* message_descriptor); +// Retrieves the Python class registered with the given message descriptor. +// The class is created if not done yet. Returns a *new* reference. +CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self, + const Descriptor* message_descriptor); +} // namespace message_factory + +// Initialize objects used by this module. +// On error, returns false with a Python exception set. +bool InitMessageFactory(); + +} // namespace python +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__ diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/message_module.cc b/contrib/python/protobuf/py3/google/protobuf/pyext/message_module.cc new file mode 100644 index 0000000000..b5975f76c5 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/message_module.cc @@ -0,0 +1,132 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include <Python.h> + +#include <google/protobuf/message_lite.h> +#include <google/protobuf/pyext/descriptor_pool.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/message_factory.h> +#include <google/protobuf/proto_api.h> + +namespace { + +// C++ API. Clients get at this via proto_api.h +struct ApiImplementation : google::protobuf::python::PyProto_API { + const google::protobuf::Message* GetMessagePointer(PyObject* msg) const override { + return google::protobuf::python::PyMessage_GetMessagePointer(msg); + } + google::protobuf::Message* GetMutableMessagePointer(PyObject* msg) const override { + return google::protobuf::python::PyMessage_GetMutableMessagePointer(msg); + } + const google::protobuf::DescriptorPool* GetDefaultDescriptorPool() const override { + return google::protobuf::python::GetDefaultDescriptorPool()->pool; + } + + google::protobuf::MessageFactory* GetDefaultMessageFactory() const override { + return google::protobuf::python::GetDefaultDescriptorPool() + ->py_message_factory->message_factory; + } + PyObject* NewMessage(const google::protobuf::Descriptor* descriptor, + PyObject* py_message_factory) const override { + return google::protobuf::python::PyMessage_New(descriptor, py_message_factory); + } + PyObject* NewMessageOwnedExternally( + google::protobuf::Message* msg, PyObject* py_message_factory) const override { + return google::protobuf::python::PyMessage_NewMessageOwnedExternally( + msg, py_message_factory); + } +}; + +} // namespace + +static const char module_docstring[] = + "python-proto2 is a module that can be used to enhance proto2 Python API\n" + "performance.\n" + "\n" + "It provides access to the protocol buffers C++ reflection API that\n" + "implements the basic protocol buffer functions."; + +static PyMethodDef ModuleMethods[] = { + {"SetAllowOversizeProtos", + (PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos, METH_O, + "Enable/disable oversize proto parsing."}, + // DO NOT USE: For migration and testing only. + {NULL, NULL}}; + +#if PY_MAJOR_VERSION >= 3 +static struct PyModuleDef _module = {PyModuleDef_HEAD_INIT, + "_message", + module_docstring, + -1, + ModuleMethods, /* m_methods */ + NULL, + NULL, + NULL, + NULL}; +#define INITFUNC PyInit__message +#define INITFUNC_ERRORVAL NULL +#else // Python 2 +#define INITFUNC init_message +#define INITFUNC_ERRORVAL +#endif + +PyMODINIT_FUNC INITFUNC() { + PyObject* m; +#if PY_MAJOR_VERSION >= 3 + m = PyModule_Create(&_module); +#else + m = Py_InitModule3("_message", ModuleMethods, module_docstring); +#endif + if (m == NULL) { + return INITFUNC_ERRORVAL; + } + + if (!google::protobuf::python::InitProto2MessageModule(m)) { + Py_DECREF(m); + return INITFUNC_ERRORVAL; + } + + // Adds the C++ API + if (PyObject* api = PyCapsule_New( + new ApiImplementation(), google::protobuf::python::PyProtoAPICapsuleName(), + [](PyObject* o) { + delete (ApiImplementation*)PyCapsule_GetPointer( + o, google::protobuf::python::PyProtoAPICapsuleName()); + })) { + PyModule_AddObject(m, "proto_API", api); + } else { + return INITFUNC_ERRORVAL; + } + +#if PY_MAJOR_VERSION >= 3 + return m; +#endif +} diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/repeated_composite_container.cc b/contrib/python/protobuf/py3/google/protobuf/pyext/repeated_composite_container.cc new file mode 100644 index 0000000000..f3d6fc3092 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/repeated_composite_container.cc @@ -0,0 +1,612 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#include <google/protobuf/pyext/repeated_composite_container.h> + +#include <memory> + +#include <google/protobuf/stubs/logging.h> +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/descriptor.h> +#include <google/protobuf/dynamic_message.h> +#include <google/protobuf/message.h> +#include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/descriptor_pool.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/message_factory.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> +#include <google/protobuf/reflection.h> +#include <google/protobuf/stubs/map_util.h> + +#if PY_MAJOR_VERSION >= 3 + #define PyInt_Check PyLong_Check + #define PyInt_AsLong PyLong_AsLong + #define PyInt_FromLong PyLong_FromLong +#endif + +namespace google { +namespace protobuf { +namespace python { + +namespace repeated_composite_container { + +// --------------------------------------------------------------------- +// len() + +static Py_ssize_t Length(PyObject* pself) { + RepeatedCompositeContainer* self = + reinterpret_cast<RepeatedCompositeContainer*>(pself); + + Message* message = self->parent->message; + return message->GetReflection()->FieldSize(*message, + self->parent_field_descriptor); +} + +// --------------------------------------------------------------------- +// add() + +PyObject* Add(RepeatedCompositeContainer* self, PyObject* args, + PyObject* kwargs) { + if (cmessage::AssureWritable(self->parent) == -1) return nullptr; + Message* message = self->parent->message; + + Message* sub_message = + message->GetReflection()->AddMessage( + message, + self->parent_field_descriptor, + self->child_message_class->py_message_factory->message_factory); + CMessage* cmsg = self->parent->BuildSubMessageFromPointer( + self->parent_field_descriptor, sub_message, self->child_message_class); + + if (cmessage::InitAttributes(cmsg, args, kwargs) < 0) { + message->GetReflection()->RemoveLast( + message, self->parent_field_descriptor); + Py_DECREF(cmsg); + return nullptr; + } + + return cmsg->AsPyObject(); +} + +static PyObject* AddMethod(PyObject* self, PyObject* args, PyObject* kwargs) { + return Add(reinterpret_cast<RepeatedCompositeContainer*>(self), args, kwargs); +} + +// --------------------------------------------------------------------- +// append() + +static PyObject* AddMessage(RepeatedCompositeContainer* self, PyObject* value) { + cmessage::AssureWritable(self->parent); + PyObject* py_cmsg; + Message* message = self->parent->message; + const Reflection* reflection = message->GetReflection(); + py_cmsg = Add(self, nullptr, nullptr); + if (py_cmsg == nullptr) return nullptr; + CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg); + if (ScopedPyObjectPtr(cmessage::MergeFrom(cmsg, value)) == nullptr) { + reflection->RemoveLast( + message, self->parent_field_descriptor); + Py_DECREF(cmsg); + return nullptr; + } + return py_cmsg; +} + +static PyObject* AppendMethod(PyObject* pself, PyObject* value) { + RepeatedCompositeContainer* self = + reinterpret_cast<RepeatedCompositeContainer*>(pself); + ScopedPyObjectPtr py_cmsg(AddMessage(self, value)); + if (py_cmsg == nullptr) { + return nullptr; + } + + Py_RETURN_NONE; +} + +// --------------------------------------------------------------------- +// insert() +static PyObject* Insert(PyObject* pself, PyObject* args) { + RepeatedCompositeContainer* self = + reinterpret_cast<RepeatedCompositeContainer*>(pself); + + Py_ssize_t index; + PyObject* value; + if (!PyArg_ParseTuple(args, "nO", &index, &value)) { + return nullptr; + } + + ScopedPyObjectPtr py_cmsg(AddMessage(self, value)); + if (py_cmsg == nullptr) { + return nullptr; + } + + // Swap the element to right position. + Message* message = self->parent->message; + const Reflection* reflection = message->GetReflection(); + const FieldDescriptor* field_descriptor = self->parent_field_descriptor; + Py_ssize_t length = reflection->FieldSize(*message, field_descriptor) - 1; + Py_ssize_t end_index = index; + if (end_index < 0) end_index += length; + if (end_index < 0) end_index = 0; + for (Py_ssize_t i = length; i > end_index; i --) { + reflection->SwapElements(message, field_descriptor, i, i - 1); + } + + Py_RETURN_NONE; +} + +// --------------------------------------------------------------------- +// extend() + +PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value) { + cmessage::AssureWritable(self->parent); + ScopedPyObjectPtr iter(PyObject_GetIter(value)); + if (iter == nullptr) { + PyErr_SetString(PyExc_TypeError, "Value must be iterable"); + return nullptr; + } + ScopedPyObjectPtr next; + while ((next.reset(PyIter_Next(iter.get()))) != nullptr) { + if (!PyObject_TypeCheck(next.get(), CMessage_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a cmessage"); + return nullptr; + } + ScopedPyObjectPtr new_message(Add(self, nullptr, nullptr)); + if (new_message == nullptr) { + return nullptr; + } + CMessage* new_cmessage = reinterpret_cast<CMessage*>(new_message.get()); + if (ScopedPyObjectPtr(cmessage::MergeFrom(new_cmessage, next.get())) == + nullptr) { + return nullptr; + } + } + if (PyErr_Occurred()) { + return nullptr; + } + Py_RETURN_NONE; +} + +static PyObject* ExtendMethod(PyObject* self, PyObject* value) { + return Extend(reinterpret_cast<RepeatedCompositeContainer*>(self), value); +} + +PyObject* MergeFrom(RepeatedCompositeContainer* self, PyObject* other) { + return Extend(self, other); +} + +static PyObject* MergeFromMethod(PyObject* self, PyObject* other) { + return MergeFrom(reinterpret_cast<RepeatedCompositeContainer*>(self), other); +} + +// This function does not check the bounds. +static PyObject* GetItem(RepeatedCompositeContainer* self, Py_ssize_t index, + Py_ssize_t length = -1) { + if (length == -1) { + Message* message = self->parent->message; + const Reflection* reflection = message->GetReflection(); + length = reflection->FieldSize(*message, self->parent_field_descriptor); + } + if (index < 0 || index >= length) { + PyErr_Format(PyExc_IndexError, "list index (%zd) out of range", index); + return nullptr; + } + Message* message = self->parent->message; + Message* sub_message = message->GetReflection()->MutableRepeatedMessage( + message, self->parent_field_descriptor, index); + return self->parent + ->BuildSubMessageFromPointer(self->parent_field_descriptor, sub_message, + self->child_message_class) + ->AsPyObject(); +} + +PyObject* Subscript(RepeatedCompositeContainer* self, PyObject* item) { + Message* message = self->parent->message; + const Reflection* reflection = message->GetReflection(); + Py_ssize_t length = + reflection->FieldSize(*message, self->parent_field_descriptor); + + if (PyIndex_Check(item)) { + Py_ssize_t index; + index = PyNumber_AsSsize_t(item, PyExc_IndexError); + if (index == -1 && PyErr_Occurred()) return nullptr; + if (index < 0) index += length; + return GetItem(self, index, length); + } else if (PySlice_Check(item)) { + Py_ssize_t from, to, step, slicelength, cur, i; + PyObject* result; + +#if PY_MAJOR_VERSION >= 3 + if (PySlice_GetIndicesEx(item, + length, &from, &to, &step, &slicelength) == -1) { +#else + if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(item), + length, &from, &to, &step, &slicelength) == -1) { +#endif + return nullptr; + } + + if (slicelength <= 0) { + return PyList_New(0); + } else { + result = PyList_New(slicelength); + if (!result) return nullptr; + + for (cur = from, i = 0; i < slicelength; cur += step, i++) { + PyList_SET_ITEM(result, i, GetItem(self, cur, length)); + } + + return result; + } + } else { + PyErr_Format(PyExc_TypeError, "indices must be integers, not %.200s", + item->ob_type->tp_name); + return nullptr; + } +} + +static PyObject* SubscriptMethod(PyObject* self, PyObject* slice) { + return Subscript(reinterpret_cast<RepeatedCompositeContainer*>(self), slice); +} + +int AssignSubscript(RepeatedCompositeContainer* self, + PyObject* slice, + PyObject* value) { + if (value != nullptr) { + PyErr_SetString(PyExc_TypeError, "does not support assignment"); + return -1; + } + + return cmessage::DeleteRepeatedField(self->parent, + self->parent_field_descriptor, slice); +} + +static int AssignSubscriptMethod(PyObject* self, PyObject* slice, + PyObject* value) { + return AssignSubscript(reinterpret_cast<RepeatedCompositeContainer*>(self), + slice, value); +} + +static PyObject* Remove(PyObject* pself, PyObject* value) { + RepeatedCompositeContainer* self = + reinterpret_cast<RepeatedCompositeContainer*>(pself); + Py_ssize_t len = Length(reinterpret_cast<PyObject*>(self)); + + for (Py_ssize_t i = 0; i < len; i++) { + ScopedPyObjectPtr item(GetItem(self, i, len)); + if (item == nullptr) { + return nullptr; + } + int result = PyObject_RichCompareBool(item.get(), value, Py_EQ); + if (result < 0) { + return nullptr; + } + if (result) { + ScopedPyObjectPtr py_index(PyLong_FromSsize_t(i)); + if (AssignSubscript(self, py_index.get(), nullptr) < 0) { + return nullptr; + } + Py_RETURN_NONE; + } + } + PyErr_SetString(PyExc_ValueError, "Item to delete not in list"); + return nullptr; +} + +static PyObject* RichCompare(PyObject* pself, PyObject* other, int opid) { + RepeatedCompositeContainer* self = + reinterpret_cast<RepeatedCompositeContainer*>(pself); + + if (!PyObject_TypeCheck(other, &RepeatedCompositeContainer_Type)) { + PyErr_SetString(PyExc_TypeError, + "Can only compare repeated composite fields " + "against other repeated composite fields."); + return nullptr; + } + if (opid == Py_EQ || opid == Py_NE) { + // TODO(anuraag): Don't make new lists just for this... + ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr)); + if (full_slice == nullptr) { + return nullptr; + } + ScopedPyObjectPtr list(Subscript(self, full_slice.get())); + if (list == nullptr) { + return nullptr; + } + ScopedPyObjectPtr other_list( + Subscript(reinterpret_cast<RepeatedCompositeContainer*>(other), + full_slice.get())); + if (other_list == nullptr) { + return nullptr; + } + return PyObject_RichCompare(list.get(), other_list.get(), opid); + } else { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } +} + +static PyObject* ToStr(PyObject* pself) { + ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr)); + if (full_slice == nullptr) { + return nullptr; + } + ScopedPyObjectPtr list(Subscript( + reinterpret_cast<RepeatedCompositeContainer*>(pself), full_slice.get())); + if (list == nullptr) { + return nullptr; + } + return PyObject_Repr(list.get()); +} + +// --------------------------------------------------------------------- +// sort() + +static void ReorderAttached(RepeatedCompositeContainer* self, + PyObject* child_list) { + Message* message = self->parent->message; + const Reflection* reflection = message->GetReflection(); + const FieldDescriptor* descriptor = self->parent_field_descriptor; + const Py_ssize_t length = Length(reinterpret_cast<PyObject*>(self)); + + // We need to rearrange things to match python's sort order. Because there + // was already an O(n*log(n)) step in python and a bunch of reflection, we + // expect an O(n**2) step in C++ won't hurt too much. + for (Py_ssize_t i = 0; i < length; ++i) { + Message* child_message = + reinterpret_cast<CMessage*>(PyList_GET_ITEM(child_list, i))->message; + for (Py_ssize_t j = i; j < length; ++j) { + if (child_message == + &reflection->GetRepeatedMessage(*message, descriptor, j)) { + reflection->SwapElements(message, descriptor, i, j); + break; + } + } + } +} + +// Returns 0 if successful; returns -1 and sets an exception if +// unsuccessful. +static int SortPythonMessages(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwds) { + ScopedPyObjectPtr child_list( + PySequence_List(reinterpret_cast<PyObject*>(self))); + if (child_list == nullptr) { + return -1; + } + ScopedPyObjectPtr m(PyObject_GetAttrString(child_list.get(), "sort")); + if (m == nullptr) return -1; + if (ScopedPyObjectPtr(PyObject_Call(m.get(), args, kwds)) == nullptr) + return -1; + ReorderAttached(self, child_list.get()); + return 0; +} + +static PyObject* Sort(PyObject* pself, PyObject* args, PyObject* kwds) { + RepeatedCompositeContainer* self = + reinterpret_cast<RepeatedCompositeContainer*>(pself); + + // Support the old sort_function argument for backwards + // compatibility. + if (kwds != nullptr) { + PyObject* sort_func = PyDict_GetItemString(kwds, "sort_function"); + if (sort_func != nullptr) { + // Must set before deleting as sort_func is a borrowed reference + // and kwds might be the only thing keeping it alive. + PyDict_SetItemString(kwds, "cmp", sort_func); + PyDict_DelItemString(kwds, "sort_function"); + } + } + + if (SortPythonMessages(self, args, kwds) < 0) { + return nullptr; + } + Py_RETURN_NONE; +} + +// --------------------------------------------------------------------- +// reverse() + +// Returns 0 if successful; returns -1 and sets an exception if +// unsuccessful. +static int ReversePythonMessages(RepeatedCompositeContainer* self) { + ScopedPyObjectPtr child_list( + PySequence_List(reinterpret_cast<PyObject*>(self))); + if (child_list == nullptr) { + return -1; + } + if (ScopedPyObjectPtr( + PyObject_CallMethod(child_list.get(), "reverse", nullptr)) == nullptr) + return -1; + ReorderAttached(self, child_list.get()); + return 0; +} + +static PyObject* Reverse(PyObject* pself) { + RepeatedCompositeContainer* self = + reinterpret_cast<RepeatedCompositeContainer*>(pself); + + if (ReversePythonMessages(self) < 0) { + return nullptr; + } + Py_RETURN_NONE; +} + +// --------------------------------------------------------------------- + +static PyObject* Item(PyObject* pself, Py_ssize_t index) { + RepeatedCompositeContainer* self = + reinterpret_cast<RepeatedCompositeContainer*>(pself); + return GetItem(self, index); +} + +static PyObject* Pop(PyObject* pself, PyObject* args) { + RepeatedCompositeContainer* self = + reinterpret_cast<RepeatedCompositeContainer*>(pself); + + Py_ssize_t index = -1; + if (!PyArg_ParseTuple(args, "|n", &index)) { + return nullptr; + } + Py_ssize_t length = Length(pself); + if (index < 0) index += length; + PyObject* item = GetItem(self, index, length); + if (item == nullptr) { + return nullptr; + } + ScopedPyObjectPtr py_index(PyLong_FromSsize_t(index)); + if (AssignSubscript(self, py_index.get(), nullptr) < 0) { + return nullptr; + } + return item; +} + +PyObject* DeepCopy(PyObject* pself, PyObject* arg) { + return reinterpret_cast<RepeatedCompositeContainer*>(pself)->DeepCopy(); +} + +// The private constructor of RepeatedCompositeContainer objects. +RepeatedCompositeContainer *NewContainer( + CMessage* parent, + const FieldDescriptor* parent_field_descriptor, + CMessageClass* child_message_class) { + if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { + return nullptr; + } + + RepeatedCompositeContainer* self = + reinterpret_cast<RepeatedCompositeContainer*>( + PyType_GenericAlloc(&RepeatedCompositeContainer_Type, 0)); + if (self == nullptr) { + return nullptr; + } + + Py_INCREF(parent); + self->parent = parent; + self->parent_field_descriptor = parent_field_descriptor; + Py_INCREF(child_message_class); + self->child_message_class = child_message_class; + return self; +} + +static void Dealloc(PyObject* pself) { + RepeatedCompositeContainer* self = + reinterpret_cast<RepeatedCompositeContainer*>(pself); + self->RemoveFromParentCache(); + Py_CLEAR(self->child_message_class); + Py_TYPE(self)->tp_free(pself); +} + +static PySequenceMethods SqMethods = { + Length, /* sq_length */ + nullptr, /* sq_concat */ + nullptr, /* sq_repeat */ + Item /* sq_item */ +}; + +static PyMappingMethods MpMethods = { + Length, /* mp_length */ + SubscriptMethod, /* mp_subscript */ + AssignSubscriptMethod, /* mp_ass_subscript */ +}; + +static PyMethodDef Methods[] = { + {"__deepcopy__", DeepCopy, METH_VARARGS, "Makes a deep copy of the class."}, + {"add", reinterpret_cast<PyCFunction>(AddMethod), + METH_VARARGS | METH_KEYWORDS, "Adds an object to the repeated container."}, + {"append", AppendMethod, METH_O, + "Appends a message to the end of the repeated container."}, + {"insert", Insert, METH_VARARGS, + "Inserts a message before the specified index."}, + {"extend", ExtendMethod, METH_O, "Adds objects to the repeated container."}, + {"pop", Pop, METH_VARARGS, + "Removes an object from the repeated container and returns it."}, + {"remove", Remove, METH_O, + "Removes an object from the repeated container."}, + {"sort", reinterpret_cast<PyCFunction>(Sort), METH_VARARGS | METH_KEYWORDS, + "Sorts the repeated container."}, + {"reverse", reinterpret_cast<PyCFunction>(Reverse), METH_NOARGS, + "Reverses elements order of the repeated container."}, + {"MergeFrom", MergeFromMethod, METH_O, + "Adds objects to the repeated container."}, + {nullptr, nullptr}}; + +} // namespace repeated_composite_container + +PyTypeObject RepeatedCompositeContainer_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME + ".RepeatedCompositeContainer", // tp_name + sizeof(RepeatedCompositeContainer), // tp_basicsize + 0, // tp_itemsize + repeated_composite_container::Dealloc, // tp_dealloc +#if PY_VERSION_HEX >= 0x03080000 + 0, // tp_vectorcall_offset +#else + nullptr, // tp_print +#endif + nullptr, // tp_getattr + nullptr, // tp_setattr + nullptr, // tp_compare + repeated_composite_container::ToStr, // tp_repr + nullptr, // tp_as_number + &repeated_composite_container::SqMethods, // tp_as_sequence + &repeated_composite_container::MpMethods, // tp_as_mapping + PyObject_HashNotImplemented, // tp_hash + nullptr, // tp_call + nullptr, // tp_str + nullptr, // tp_getattro + nullptr, // tp_setattro + nullptr, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Repeated scalar container", // tp_doc + nullptr, // tp_traverse + nullptr, // tp_clear + repeated_composite_container::RichCompare, // tp_richcompare + 0, // tp_weaklistoffset + nullptr, // tp_iter + nullptr, // tp_iternext + repeated_composite_container::Methods, // tp_methods + nullptr, // tp_members + nullptr, // tp_getset + nullptr, // tp_base + nullptr, // tp_dict + nullptr, // tp_descr_get + nullptr, // tp_descr_set + 0, // tp_dictoffset + nullptr, // tp_init +}; + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/repeated_composite_container.h b/contrib/python/protobuf/py3/google/protobuf/pyext/repeated_composite_container.h new file mode 100644 index 0000000000..e241827ef5 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/repeated_composite_container.h @@ -0,0 +1,112 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__ + +#include <Python.h> + +#include <memory> +#include <string> +#include <vector> + +#include <google/protobuf/pyext/message.h> + +namespace google { +namespace protobuf { + +class FieldDescriptor; +class Message; + +namespace python { + +struct CMessageClass; + +// A RepeatedCompositeContainer always has a parent message. +// The parent message also caches reference to items of the container. +typedef struct RepeatedCompositeContainer : public ContainerBase { + // The type used to create new child messages. + CMessageClass* child_message_class; +} RepeatedCompositeContainer; + +extern PyTypeObject RepeatedCompositeContainer_Type; + +namespace repeated_composite_container { + +// Builds a RepeatedCompositeContainer object, from a parent message and a +// field descriptor. +RepeatedCompositeContainer* NewContainer( + CMessage* parent, + const FieldDescriptor* parent_field_descriptor, + CMessageClass *child_message_class); + +// Appends a new CMessage to the container and returns it. The +// CMessage is initialized using the content of kwargs. +// +// Returns a new reference if successful; returns NULL and sets an +// exception if unsuccessful. +PyObject* Add(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwargs); + +// Appends all the CMessages in the input iterator to the container. +// +// Returns None if successful; returns NULL and sets an exception if +// unsuccessful. +PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value); + +// Appends a new message to the container for each message in the +// input iterator, merging each data element in. Equivalent to extend. +// +// Returns None if successful; returns NULL and sets an exception if +// unsuccessful. +PyObject* MergeFrom(RepeatedCompositeContainer* self, PyObject* other); + +// Accesses messages in the container. +// +// Returns a new reference to the message for an integer parameter. +// Returns a new reference to a list of messages for a slice. +PyObject* Subscript(RepeatedCompositeContainer* self, PyObject* slice); + +// Deletes items from the container (cannot be used for assignment). +// +// Returns 0 on success, -1 on failure. +int AssignSubscript(RepeatedCompositeContainer* self, + PyObject* slice, + PyObject* value); +} // namespace repeated_composite_container +} // namespace python +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__ diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/repeated_scalar_container.cc b/contrib/python/protobuf/py3/google/protobuf/pyext/repeated_scalar_container.cc new file mode 100644 index 0000000000..3a41a58adb --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/repeated_scalar_container.cc @@ -0,0 +1,796 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#include <google/protobuf/pyext/repeated_scalar_container.h> + +#include <cstdint> +#include <memory> + +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/stubs/logging.h> +#include <google/protobuf/descriptor.h> +#include <google/protobuf/dynamic_message.h> +#include <google/protobuf/message.h> +#include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/descriptor_pool.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +#if PY_MAJOR_VERSION >= 3 +#define PyInt_FromLong PyLong_FromLong +#if PY_VERSION_HEX < 0x03030000 +#error "Python 3.0 - 3.2 are not supported." +#else +#define PyString_AsString(ob) \ + (PyUnicode_Check(ob) ? PyUnicode_AsUTF8(ob) : PyBytes_AsString(ob)) +#endif +#endif + +namespace google { +namespace protobuf { +namespace python { + +namespace repeated_scalar_container { + +static int InternalAssignRepeatedField(RepeatedScalarContainer* self, + PyObject* list) { + Message* message = self->parent->message; + message->GetReflection()->ClearField(message, self->parent_field_descriptor); + for (Py_ssize_t i = 0; i < PyList_GET_SIZE(list); ++i) { + PyObject* value = PyList_GET_ITEM(list, i); + if (ScopedPyObjectPtr(Append(self, value)) == nullptr) { + return -1; + } + } + return 0; +} + +static Py_ssize_t Len(PyObject* pself) { + RepeatedScalarContainer* self = + reinterpret_cast<RepeatedScalarContainer*>(pself); + Message* message = self->parent->message; + return message->GetReflection()->FieldSize(*message, + self->parent_field_descriptor); +} + +static int AssignItem(PyObject* pself, Py_ssize_t index, PyObject* arg) { + RepeatedScalarContainer* self = + reinterpret_cast<RepeatedScalarContainer*>(pself); + + cmessage::AssureWritable(self->parent); + Message* message = self->parent->message; + const FieldDescriptor* field_descriptor = self->parent_field_descriptor; + + const Reflection* reflection = message->GetReflection(); + int field_size = reflection->FieldSize(*message, field_descriptor); + if (index < 0) { + index = field_size + index; + } + if (index < 0 || index >= field_size) { + PyErr_Format(PyExc_IndexError, "list assignment index (%d) out of range", + static_cast<int>(index)); + return -1; + } + + if (arg == nullptr) { + ScopedPyObjectPtr py_index(PyLong_FromLong(index)); + return cmessage::DeleteRepeatedField(self->parent, field_descriptor, + py_index.get()); + } + + if (PySequence_Check(arg) && !(PyBytes_Check(arg) || PyUnicode_Check(arg))) { + PyErr_SetString(PyExc_TypeError, "Value must be scalar"); + return -1; + } + + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: { + GOOGLE_CHECK_GET_INT32(arg, value, -1); + reflection->SetRepeatedInt32(message, field_descriptor, index, value); + break; + } + case FieldDescriptor::CPPTYPE_INT64: { + GOOGLE_CHECK_GET_INT64(arg, value, -1); + reflection->SetRepeatedInt64(message, field_descriptor, index, value); + break; + } + case FieldDescriptor::CPPTYPE_UINT32: { + GOOGLE_CHECK_GET_UINT32(arg, value, -1); + reflection->SetRepeatedUInt32(message, field_descriptor, index, value); + break; + } + case FieldDescriptor::CPPTYPE_UINT64: { + GOOGLE_CHECK_GET_UINT64(arg, value, -1); + reflection->SetRepeatedUInt64(message, field_descriptor, index, value); + break; + } + case FieldDescriptor::CPPTYPE_FLOAT: { + GOOGLE_CHECK_GET_FLOAT(arg, value, -1); + reflection->SetRepeatedFloat(message, field_descriptor, index, value); + break; + } + case FieldDescriptor::CPPTYPE_DOUBLE: { + GOOGLE_CHECK_GET_DOUBLE(arg, value, -1); + reflection->SetRepeatedDouble(message, field_descriptor, index, value); + break; + } + case FieldDescriptor::CPPTYPE_BOOL: { + GOOGLE_CHECK_GET_BOOL(arg, value, -1); + reflection->SetRepeatedBool(message, field_descriptor, index, value); + break; + } + case FieldDescriptor::CPPTYPE_STRING: { + if (!CheckAndSetString(arg, message, field_descriptor, reflection, false, + index)) { + return -1; + } + break; + } + case FieldDescriptor::CPPTYPE_ENUM: { + GOOGLE_CHECK_GET_INT32(arg, value, -1); + if (reflection->SupportsUnknownEnumValues()) { + reflection->SetRepeatedEnumValue(message, field_descriptor, index, + value); + } else { + const EnumDescriptor* enum_descriptor = field_descriptor->enum_type(); + const EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != nullptr) { + reflection->SetRepeatedEnum(message, field_descriptor, index, + enum_value); + } else { + ScopedPyObjectPtr s(PyObject_Str(arg)); + if (s != nullptr) { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %s", + PyString_AsString(s.get())); + } + return -1; + } + } + break; + } + default: + PyErr_Format(PyExc_SystemError, + "Adding value to a field of unknown type %d", + field_descriptor->cpp_type()); + return -1; + } + return 0; +} + +static PyObject* Item(PyObject* pself, Py_ssize_t index) { + RepeatedScalarContainer* self = + reinterpret_cast<RepeatedScalarContainer*>(pself); + + Message* message = self->parent->message; + const FieldDescriptor* field_descriptor = self->parent_field_descriptor; + const Reflection* reflection = message->GetReflection(); + + int field_size = reflection->FieldSize(*message, field_descriptor); + if (index < 0) { + index = field_size + index; + } + if (index < 0 || index >= field_size) { + PyErr_Format(PyExc_IndexError, "list index (%zd) out of range", index); + return nullptr; + } + + PyObject* result = nullptr; + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: { + int32_t value = + reflection->GetRepeatedInt32(*message, field_descriptor, index); + result = PyInt_FromLong(value); + break; + } + case FieldDescriptor::CPPTYPE_INT64: { + int64_t value = + reflection->GetRepeatedInt64(*message, field_descriptor, index); + result = PyLong_FromLongLong(value); + break; + } + case FieldDescriptor::CPPTYPE_UINT32: { + uint32_t value = + reflection->GetRepeatedUInt32(*message, field_descriptor, index); + result = PyLong_FromLongLong(value); + break; + } + case FieldDescriptor::CPPTYPE_UINT64: { + uint64_t value = + reflection->GetRepeatedUInt64(*message, field_descriptor, index); + result = PyLong_FromUnsignedLongLong(value); + break; + } + case FieldDescriptor::CPPTYPE_FLOAT: { + float value = + reflection->GetRepeatedFloat(*message, field_descriptor, index); + result = PyFloat_FromDouble(value); + break; + } + case FieldDescriptor::CPPTYPE_DOUBLE: { + double value = + reflection->GetRepeatedDouble(*message, field_descriptor, index); + result = PyFloat_FromDouble(value); + break; + } + case FieldDescriptor::CPPTYPE_BOOL: { + bool value = + reflection->GetRepeatedBool(*message, field_descriptor, index); + result = PyBool_FromLong(value ? 1 : 0); + break; + } + case FieldDescriptor::CPPTYPE_ENUM: { + const EnumValueDescriptor* enum_value = + message->GetReflection()->GetRepeatedEnum(*message, field_descriptor, + index); + result = PyInt_FromLong(enum_value->number()); + break; + } + case FieldDescriptor::CPPTYPE_STRING: { + TProtoStringType scratch; + const TProtoStringType& value = reflection->GetRepeatedStringReference( + *message, field_descriptor, index, &scratch); + result = ToStringObject(field_descriptor, value); + break; + } + default: + PyErr_Format(PyExc_SystemError, + "Getting value from a repeated field of unknown type %d", + field_descriptor->cpp_type()); + } + + return result; +} + +static PyObject* Subscript(PyObject* pself, PyObject* slice) { + Py_ssize_t from; + Py_ssize_t to; + Py_ssize_t step; + Py_ssize_t length; + Py_ssize_t slicelength; + bool return_list = false; +#if PY_MAJOR_VERSION < 3 + if (PyInt_Check(slice)) { + from = to = PyInt_AsLong(slice); + } else // NOLINT +#endif + if (PyLong_Check(slice)) { + from = to = PyLong_AsLong(slice); + } else if (PySlice_Check(slice)) { + length = Len(pself); +#if PY_MAJOR_VERSION >= 3 + if (PySlice_GetIndicesEx(slice, length, &from, &to, &step, &slicelength) == + -1) { +#else + if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(slice), length, + &from, &to, &step, &slicelength) == -1) { +#endif + return nullptr; + } + return_list = true; + } else { + PyErr_SetString(PyExc_TypeError, "list indices must be integers"); + return nullptr; + } + + if (!return_list) { + return Item(pself, from); + } + + PyObject* list = PyList_New(0); + if (list == nullptr) { + return nullptr; + } + if (from <= to) { + if (step < 0) { + return list; + } + for (Py_ssize_t index = from; index < to; index += step) { + if (index < 0 || index >= length) { + break; + } + ScopedPyObjectPtr s(Item(pself, index)); + PyList_Append(list, s.get()); + } + } else { + if (step > 0) { + return list; + } + for (Py_ssize_t index = from; index > to; index += step) { + if (index < 0 || index >= length) { + break; + } + ScopedPyObjectPtr s(Item(pself, index)); + PyList_Append(list, s.get()); + } + } + return list; +} + +PyObject* Append(RepeatedScalarContainer* self, PyObject* item) { + cmessage::AssureWritable(self->parent); + Message* message = self->parent->message; + const FieldDescriptor* field_descriptor = self->parent_field_descriptor; + + const Reflection* reflection = message->GetReflection(); + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: { + GOOGLE_CHECK_GET_INT32(item, value, nullptr); + reflection->AddInt32(message, field_descriptor, value); + break; + } + case FieldDescriptor::CPPTYPE_INT64: { + GOOGLE_CHECK_GET_INT64(item, value, nullptr); + reflection->AddInt64(message, field_descriptor, value); + break; + } + case FieldDescriptor::CPPTYPE_UINT32: { + GOOGLE_CHECK_GET_UINT32(item, value, nullptr); + reflection->AddUInt32(message, field_descriptor, value); + break; + } + case FieldDescriptor::CPPTYPE_UINT64: { + GOOGLE_CHECK_GET_UINT64(item, value, nullptr); + reflection->AddUInt64(message, field_descriptor, value); + break; + } + case FieldDescriptor::CPPTYPE_FLOAT: { + GOOGLE_CHECK_GET_FLOAT(item, value, nullptr); + reflection->AddFloat(message, field_descriptor, value); + break; + } + case FieldDescriptor::CPPTYPE_DOUBLE: { + GOOGLE_CHECK_GET_DOUBLE(item, value, nullptr); + reflection->AddDouble(message, field_descriptor, value); + break; + } + case FieldDescriptor::CPPTYPE_BOOL: { + GOOGLE_CHECK_GET_BOOL(item, value, nullptr); + reflection->AddBool(message, field_descriptor, value); + break; + } + case FieldDescriptor::CPPTYPE_STRING: { + if (!CheckAndSetString(item, message, field_descriptor, reflection, true, + -1)) { + return nullptr; + } + break; + } + case FieldDescriptor::CPPTYPE_ENUM: { + GOOGLE_CHECK_GET_INT32(item, value, nullptr); + if (reflection->SupportsUnknownEnumValues()) { + reflection->AddEnumValue(message, field_descriptor, value); + } else { + const EnumDescriptor* enum_descriptor = field_descriptor->enum_type(); + const EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != nullptr) { + reflection->AddEnum(message, field_descriptor, enum_value); + } else { + ScopedPyObjectPtr s(PyObject_Str(item)); + if (s != nullptr) { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %s", + PyString_AsString(s.get())); + } + return nullptr; + } + } + break; + } + default: + PyErr_Format(PyExc_SystemError, + "Adding value to a field of unknown type %d", + field_descriptor->cpp_type()); + return nullptr; + } + + Py_RETURN_NONE; +} + +static PyObject* AppendMethod(PyObject* self, PyObject* item) { + return Append(reinterpret_cast<RepeatedScalarContainer*>(self), item); +} + +static int AssSubscript(PyObject* pself, PyObject* slice, PyObject* value) { + RepeatedScalarContainer* self = + reinterpret_cast<RepeatedScalarContainer*>(pself); + + Py_ssize_t from; + Py_ssize_t to; + Py_ssize_t step; + Py_ssize_t length; + Py_ssize_t slicelength; + bool create_list = false; + + cmessage::AssureWritable(self->parent); + Message* message = self->parent->message; + const FieldDescriptor* field_descriptor = self->parent_field_descriptor; + +#if PY_MAJOR_VERSION < 3 + if (PyInt_Check(slice)) { + from = to = PyInt_AsLong(slice); + } else // NOLINT +#endif + if (PyLong_Check(slice)) { + from = to = PyLong_AsLong(slice); + } else if (PySlice_Check(slice)) { + const Reflection* reflection = message->GetReflection(); + length = reflection->FieldSize(*message, field_descriptor); +#if PY_MAJOR_VERSION >= 3 + if (PySlice_GetIndicesEx(slice, length, &from, &to, &step, &slicelength) == + -1) { +#else + if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(slice), length, + &from, &to, &step, &slicelength) == -1) { +#endif + return -1; + } + create_list = true; + } else { + PyErr_SetString(PyExc_TypeError, "list indices must be integers"); + return -1; + } + + if (value == nullptr) { + return cmessage::DeleteRepeatedField(self->parent, field_descriptor, slice); + } + + if (!create_list) { + return AssignItem(pself, from, value); + } + + ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr)); + if (full_slice == nullptr) { + return -1; + } + ScopedPyObjectPtr new_list(Subscript(pself, full_slice.get())); + if (new_list == nullptr) { + return -1; + } + if (PySequence_SetSlice(new_list.get(), from, to, value) < 0) { + return -1; + } + + return InternalAssignRepeatedField(self, new_list.get()); +} + +PyObject* Extend(RepeatedScalarContainer* self, PyObject* value) { + cmessage::AssureWritable(self->parent); + + // TODO(ptucker): Deprecate this behavior. b/18413862 + if (value == Py_None) { + Py_RETURN_NONE; + } + if ((Py_TYPE(value)->tp_as_sequence == nullptr) && PyObject_Not(value)) { + Py_RETURN_NONE; + } + + ScopedPyObjectPtr iter(PyObject_GetIter(value)); + if (iter == nullptr) { + PyErr_SetString(PyExc_TypeError, "Value must be iterable"); + return nullptr; + } + ScopedPyObjectPtr next; + while ((next.reset(PyIter_Next(iter.get()))) != nullptr) { + if (ScopedPyObjectPtr(Append(self, next.get())) == nullptr) { + return nullptr; + } + } + if (PyErr_Occurred()) { + return nullptr; + } + Py_RETURN_NONE; +} + +static PyObject* Insert(PyObject* pself, PyObject* args) { + RepeatedScalarContainer* self = + reinterpret_cast<RepeatedScalarContainer*>(pself); + + Py_ssize_t index; + PyObject* value; + if (!PyArg_ParseTuple(args, "lO", &index, &value)) { + return nullptr; + } + ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr)); + ScopedPyObjectPtr new_list(Subscript(pself, full_slice.get())); + if (PyList_Insert(new_list.get(), index, value) < 0) { + return nullptr; + } + int ret = InternalAssignRepeatedField(self, new_list.get()); + if (ret < 0) { + return nullptr; + } + Py_RETURN_NONE; +} + +static PyObject* Remove(PyObject* pself, PyObject* value) { + Py_ssize_t match_index = -1; + for (Py_ssize_t i = 0; i < Len(pself); ++i) { + ScopedPyObjectPtr elem(Item(pself, i)); + if (PyObject_RichCompareBool(elem.get(), value, Py_EQ)) { + match_index = i; + break; + } + } + if (match_index == -1) { + PyErr_SetString(PyExc_ValueError, "remove(x): x not in container"); + return nullptr; + } + if (AssignItem(pself, match_index, nullptr) < 0) { + return nullptr; + } + Py_RETURN_NONE; +} + +static PyObject* ExtendMethod(PyObject* self, PyObject* value) { + return Extend(reinterpret_cast<RepeatedScalarContainer*>(self), value); +} + +static PyObject* RichCompare(PyObject* pself, PyObject* other, int opid) { + if (opid != Py_EQ && opid != Py_NE) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + + // Copy the contents of this repeated scalar container, and other if it is + // also a repeated scalar container, into Python lists so we can delegate + // to the list's compare method. + + ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr)); + if (full_slice == nullptr) { + return nullptr; + } + + ScopedPyObjectPtr other_list_deleter; + if (PyObject_TypeCheck(other, &RepeatedScalarContainer_Type)) { + other_list_deleter.reset(Subscript(other, full_slice.get())); + other = other_list_deleter.get(); + } + + ScopedPyObjectPtr list(Subscript(pself, full_slice.get())); + if (list == nullptr) { + return nullptr; + } + return PyObject_RichCompare(list.get(), other, opid); +} + +PyObject* Reduce(PyObject* unused_self, PyObject* unused_other) { + PyErr_Format(PickleError_class, + "can't pickle repeated message fields, convert to list first"); + return nullptr; +} + +static PyObject* Sort(PyObject* pself, PyObject* args, PyObject* kwds) { + // Support the old sort_function argument for backwards + // compatibility. + if (kwds != nullptr) { + PyObject* sort_func = PyDict_GetItemString(kwds, "sort_function"); + if (sort_func != nullptr) { + // Must set before deleting as sort_func is a borrowed reference + // and kwds might be the only thing keeping it alive. + if (PyDict_SetItemString(kwds, "cmp", sort_func) == -1) return nullptr; + if (PyDict_DelItemString(kwds, "sort_function") == -1) return nullptr; + } + } + + ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr)); + if (full_slice == nullptr) { + return nullptr; + } + ScopedPyObjectPtr list(Subscript(pself, full_slice.get())); + if (list == nullptr) { + return nullptr; + } + ScopedPyObjectPtr m(PyObject_GetAttrString(list.get(), "sort")); + if (m == nullptr) { + return nullptr; + } + ScopedPyObjectPtr res(PyObject_Call(m.get(), args, kwds)); + if (res == nullptr) { + return nullptr; + } + int ret = InternalAssignRepeatedField( + reinterpret_cast<RepeatedScalarContainer*>(pself), list.get()); + if (ret < 0) { + return nullptr; + } + Py_RETURN_NONE; +} + +static PyObject* Reverse(PyObject* pself) { + ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr)); + if (full_slice == nullptr) { + return nullptr; + } + ScopedPyObjectPtr list(Subscript(pself, full_slice.get())); + if (list == nullptr) { + return nullptr; + } + ScopedPyObjectPtr res(PyObject_CallMethod(list.get(), "reverse", nullptr)); + if (res == nullptr) { + return nullptr; + } + int ret = InternalAssignRepeatedField( + reinterpret_cast<RepeatedScalarContainer*>(pself), list.get()); + if (ret < 0) { + return nullptr; + } + Py_RETURN_NONE; +} + +static PyObject* Pop(PyObject* pself, PyObject* args) { + Py_ssize_t index = -1; + if (!PyArg_ParseTuple(args, "|n", &index)) { + return nullptr; + } + PyObject* item = Item(pself, index); + if (item == nullptr) { + PyErr_Format(PyExc_IndexError, "list index (%zd) out of range", index); + return nullptr; + } + if (AssignItem(pself, index, nullptr) < 0) { + return nullptr; + } + return item; +} + +static PyObject* ToStr(PyObject* pself) { + ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr)); + if (full_slice == nullptr) { + return nullptr; + } + ScopedPyObjectPtr list(Subscript(pself, full_slice.get())); + if (list == nullptr) { + return nullptr; + } + return PyObject_Repr(list.get()); +} + +static PyObject* MergeFrom(PyObject* pself, PyObject* arg) { + return Extend(reinterpret_cast<RepeatedScalarContainer*>(pself), arg); +} + +// The private constructor of RepeatedScalarContainer objects. +RepeatedScalarContainer* NewContainer( + CMessage* parent, const FieldDescriptor* parent_field_descriptor) { + if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { + return nullptr; + } + + RepeatedScalarContainer* self = reinterpret_cast<RepeatedScalarContainer*>( + PyType_GenericAlloc(&RepeatedScalarContainer_Type, 0)); + if (self == nullptr) { + return nullptr; + } + + Py_INCREF(parent); + self->parent = parent; + self->parent_field_descriptor = parent_field_descriptor; + + return self; +} + +PyObject* DeepCopy(PyObject* pself, PyObject* arg) { + return reinterpret_cast<RepeatedScalarContainer*>(pself)->DeepCopy(); +} + +static void Dealloc(PyObject* pself) { + reinterpret_cast<RepeatedScalarContainer*>(pself)->RemoveFromParentCache(); + Py_TYPE(pself)->tp_free(pself); +} + +static PySequenceMethods SqMethods = { + Len, /* sq_length */ + nullptr, /* sq_concat */ + nullptr, /* sq_repeat */ + Item, /* sq_item */ + nullptr, /* sq_slice */ + AssignItem /* sq_ass_item */ +}; + +static PyMappingMethods MpMethods = { + Len, /* mp_length */ + Subscript, /* mp_subscript */ + AssSubscript, /* mp_ass_subscript */ +}; + +static PyMethodDef Methods[] = { + {"__deepcopy__", DeepCopy, METH_VARARGS, "Makes a deep copy of the class."}, + {"__reduce__", Reduce, METH_NOARGS, + "Outputs picklable representation of the repeated field."}, + {"append", AppendMethod, METH_O, + "Appends an object to the repeated container."}, + {"extend", ExtendMethod, METH_O, + "Appends objects to the repeated container."}, + {"insert", Insert, METH_VARARGS, + "Inserts an object at the specified position in the container."}, + {"pop", Pop, METH_VARARGS, + "Removes an object from the repeated container and returns it."}, + {"remove", Remove, METH_O, + "Removes an object from the repeated container."}, + {"sort", reinterpret_cast<PyCFunction>(Sort), METH_VARARGS | METH_KEYWORDS, + "Sorts the repeated container."}, + {"reverse", reinterpret_cast<PyCFunction>(Reverse), METH_NOARGS, + "Reverses elements order of the repeated container."}, + {"MergeFrom", static_cast<PyCFunction>(MergeFrom), METH_O, + "Merges a repeated container into the current container."}, + {nullptr, nullptr}}; + +} // namespace repeated_scalar_container + +PyTypeObject RepeatedScalarContainer_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME + ".RepeatedScalarContainer", // tp_name + sizeof(RepeatedScalarContainer), // tp_basicsize + 0, // tp_itemsize + repeated_scalar_container::Dealloc, // tp_dealloc +#if PY_VERSION_HEX >= 0x03080000 + 0, // tp_vectorcall_offset +#else + nullptr, // tp_print +#endif + nullptr, // tp_getattr + nullptr, // tp_setattr + nullptr, // tp_compare + repeated_scalar_container::ToStr, // tp_repr + nullptr, // tp_as_number + &repeated_scalar_container::SqMethods, // tp_as_sequence + &repeated_scalar_container::MpMethods, // tp_as_mapping + PyObject_HashNotImplemented, // tp_hash + nullptr, // tp_call + nullptr, // tp_str + nullptr, // tp_getattro + nullptr, // tp_setattro + nullptr, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Repeated scalar container", // tp_doc + nullptr, // tp_traverse + nullptr, // tp_clear + repeated_scalar_container::RichCompare, // tp_richcompare + 0, // tp_weaklistoffset + nullptr, // tp_iter + nullptr, // tp_iternext + repeated_scalar_container::Methods, // tp_methods + nullptr, // tp_members + nullptr, // tp_getset + nullptr, // tp_base + nullptr, // tp_dict + nullptr, // tp_descr_get + nullptr, // tp_descr_set + 0, // tp_dictoffset + nullptr, // tp_init +}; + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/repeated_scalar_container.h b/contrib/python/protobuf/py3/google/protobuf/pyext/repeated_scalar_container.h new file mode 100644 index 0000000000..f9f0ea8f31 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/repeated_scalar_container.h @@ -0,0 +1,77 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__ + +#include <Python.h> + +#include <memory> + +#include <google/protobuf/descriptor.h> +#include <google/protobuf/pyext/message.h> + +namespace google { +namespace protobuf { +namespace python { + +typedef struct RepeatedScalarContainer : public ContainerBase { +} RepeatedScalarContainer; + +extern PyTypeObject RepeatedScalarContainer_Type; + +namespace repeated_scalar_container { + +// Builds a RepeatedScalarContainer object, from a parent message and a +// field descriptor. +extern RepeatedScalarContainer* NewContainer( + CMessage* parent, const FieldDescriptor* parent_field_descriptor); + +// Appends the scalar 'item' to the end of the container 'self'. +// +// Returns None if successful; returns NULL and sets an exception if +// unsuccessful. +PyObject* Append(RepeatedScalarContainer* self, PyObject* item); + +// Appends all the elements in the input iterator to the container. +// +// Returns None if successful; returns NULL and sets an exception if +// unsuccessful. +PyObject* Extend(RepeatedScalarContainer* self, PyObject* value); + +} // namespace repeated_scalar_container +} // namespace python +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__ diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/safe_numerics.h b/contrib/python/protobuf/py3/google/protobuf/pyext/safe_numerics.h new file mode 100644 index 0000000000..93ae640e8b --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/safe_numerics.h @@ -0,0 +1,164 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__ +// Copied from chromium with only changes to the namespace. + +#include <limits> + +#include <google/protobuf/stubs/logging.h> +#include <google/protobuf/stubs/common.h> + +namespace google { +namespace protobuf { +namespace python { + +template <bool SameSize, bool DestLarger, + bool DestIsSigned, bool SourceIsSigned> +struct IsValidNumericCastImpl; + +#define BASE_NUMERIC_CAST_CASE_SPECIALIZATION(A, B, C, D, Code) \ +template <> struct IsValidNumericCastImpl<A, B, C, D> { \ + template <class Source, class DestBounds> static inline bool Test( \ + Source source, DestBounds min, DestBounds max) { \ + return Code; \ + } \ +} + +#define BASE_NUMERIC_CAST_CASE_SAME_SIZE(DestSigned, SourceSigned, Code) \ + BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \ + true, true, DestSigned, SourceSigned, Code); \ + BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \ + true, false, DestSigned, SourceSigned, Code) + +#define BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(DestSigned, SourceSigned, Code) \ + BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \ + false, false, DestSigned, SourceSigned, Code); \ + +#define BASE_NUMERIC_CAST_CASE_DEST_LARGER(DestSigned, SourceSigned, Code) \ + BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \ + false, true, DestSigned, SourceSigned, Code); \ + +// The three top level cases are: +// - Same size +// - Source larger +// - Dest larger +// And for each of those three cases, we handle the 4 different possibilities +// of signed and unsigned. This gives 12 cases to handle, which we enumerate +// below. +// +// The last argument in each of the macros is the actual comparison code. It +// has three arguments available, source (the value), and min/max which are +// the ranges of the destination. + + +// These are the cases where both types have the same size. + +// Both signed. +BASE_NUMERIC_CAST_CASE_SAME_SIZE(true, true, true); +// Both unsigned. +BASE_NUMERIC_CAST_CASE_SAME_SIZE(false, false, true); +// Dest unsigned, Source signed. +BASE_NUMERIC_CAST_CASE_SAME_SIZE(false, true, source >= 0); +// Dest signed, Source unsigned. +// This cast is OK because Dest's max must be less than Source's. +BASE_NUMERIC_CAST_CASE_SAME_SIZE(true, false, + source <= static_cast<Source>(max)); + + +// These are the cases where Source is larger. + +// Both unsigned. +BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(false, false, source <= max); +// Both signed. +BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(true, true, + source >= min && source <= max); +// Dest is unsigned, Source is signed. +BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(false, true, + source >= 0 && source <= max); +// Dest is signed, Source is unsigned. +// This cast is OK because Dest's max must be less than Source's. +BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(true, false, + source <= static_cast<Source>(max)); + + +// These are the cases where Dest is larger. + +// Both unsigned. +BASE_NUMERIC_CAST_CASE_DEST_LARGER(false, false, true); +// Both signed. +BASE_NUMERIC_CAST_CASE_DEST_LARGER(true, true, true); +// Dest is unsigned, Source is signed. +BASE_NUMERIC_CAST_CASE_DEST_LARGER(false, true, source >= 0); +// Dest is signed, Source is unsigned. +BASE_NUMERIC_CAST_CASE_DEST_LARGER(true, false, true); + +#undef BASE_NUMERIC_CAST_CASE_SPECIALIZATION +#undef BASE_NUMERIC_CAST_CASE_SAME_SIZE +#undef BASE_NUMERIC_CAST_CASE_SOURCE_LARGER +#undef BASE_NUMERIC_CAST_CASE_DEST_LARGER + + +// The main test for whether the conversion will under or overflow. +template <class Dest, class Source> +inline bool IsValidNumericCast(Source source) { + typedef std::numeric_limits<Source> SourceLimits; + typedef std::numeric_limits<Dest> DestLimits; + static_assert(SourceLimits::is_specialized, "argument must be numeric"); + static_assert(SourceLimits::is_integer, "argument must be integral"); + static_assert(DestLimits::is_specialized, "result must be numeric"); + static_assert(DestLimits::is_integer, "result must be integral"); + + return IsValidNumericCastImpl< + sizeof(Dest) == sizeof(Source), + (sizeof(Dest) > sizeof(Source)), + DestLimits::is_signed, + SourceLimits::is_signed>::Test( + source, + DestLimits::min(), + DestLimits::max()); +} + +// checked_numeric_cast<> is analogous to static_cast<> for numeric types, +// except that it CHECKs that the specified numeric conversion will not +// overflow or underflow. Floating point arguments are not currently allowed +// (this is static_asserted), though this could be supported if necessary. +template <class Dest, class Source> +inline Dest checked_numeric_cast(Source source) { + GOOGLE_CHECK(IsValidNumericCast<Dest>(source)); + return static_cast<Dest>(source); +} + +} // namespace python +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__ diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/scoped_pyobject_ptr.h b/contrib/python/protobuf/py3/google/protobuf/pyext/scoped_pyobject_ptr.h new file mode 100644 index 0000000000..6f7fc29813 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/scoped_pyobject_ptr.h @@ -0,0 +1,100 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: tibell@google.com (Johan Tibell) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__ + +#include <google/protobuf/stubs/common.h> + +#include <Python.h> +namespace google { +namespace protobuf { +namespace python { + +// Owns a python object and decrements the reference count on destruction. +// This class is not threadsafe. +template <typename PyObjectStruct> +class ScopedPythonPtr { + public: + // Takes the ownership of the specified object to ScopedPythonPtr. + // The reference count of the specified py_object is not incremented. + explicit ScopedPythonPtr(PyObjectStruct* py_object = NULL) + : ptr_(py_object) {} + + // If a PyObject is owned, decrement its reference count. + ~ScopedPythonPtr() { Py_XDECREF(ptr_); } + + // Deletes the current owned object, if any. + // Then takes ownership of a new object without incrementing the reference + // count. + // This function must be called with a reference that you own. + // this->reset(this->get()) is wrong! + // this->reset(this->release()) is OK. + PyObjectStruct* reset(PyObjectStruct* p = NULL) { + Py_XDECREF(ptr_); + ptr_ = p; + return ptr_; + } + + // Releases ownership of the object without decrementing the reference count. + // The caller now owns the returned reference. + PyObjectStruct* release() { + PyObject* p = ptr_; + ptr_ = NULL; + return p; + } + + PyObjectStruct* get() const { return ptr_; } + + PyObject* as_pyobject() const { return reinterpret_cast<PyObject*>(ptr_); } + + // Increments the reference count of the current object. + // Should not be called when no object is held. + void inc() const { Py_INCREF(ptr_); } + + // True when a ScopedPyObjectPtr and a raw pointer refer to the same object. + // Comparison operators are non reflexive. + bool operator==(const PyObjectStruct* p) const { return ptr_ == p; } + bool operator!=(const PyObjectStruct* p) const { return ptr_ != p; } + + private: + PyObjectStruct* ptr_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ScopedPythonPtr); +}; + +typedef ScopedPythonPtr<PyObject> ScopedPyObjectPtr; + +} // namespace python +} // namespace protobuf +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__ diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/unknown_fields.cc b/contrib/python/protobuf/py3/google/protobuf/pyext/unknown_fields.cc new file mode 100644 index 0000000000..deb86e6916 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/unknown_fields.cc @@ -0,0 +1,358 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include <google/protobuf/pyext/unknown_fields.h> + +#include <Python.h> +#include <set> +#include <memory> + +#include <google/protobuf/message.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> +#include <google/protobuf/unknown_field_set.h> +#include <google/protobuf/wire_format_lite.h> + +#if PY_MAJOR_VERSION >= 3 + #define PyInt_FromLong PyLong_FromLong +#endif + +namespace google { +namespace protobuf { +namespace python { + +namespace unknown_fields { + +static Py_ssize_t Len(PyObject* pself) { + PyUnknownFields* self = + reinterpret_cast<PyUnknownFields*>(pself); + if (self->fields == NULL) { + PyErr_Format(PyExc_ValueError, + "UnknownFields does not exist. " + "The parent message might be cleared."); + return -1; + } + return self->fields->field_count(); +} + +void Clear(PyUnknownFields* self) { + for (std::set<PyUnknownFields*>::iterator it = + self->sub_unknown_fields.begin(); + it != self->sub_unknown_fields.end(); it++) { + Clear(*it); + } + self->fields = NULL; + self->sub_unknown_fields.clear(); +} + +PyObject* NewPyUnknownFieldRef(PyUnknownFields* parent, + Py_ssize_t index); + +static PyObject* Item(PyObject* pself, Py_ssize_t index) { + PyUnknownFields* self = + reinterpret_cast<PyUnknownFields*>(pself); + if (self->fields == NULL) { + PyErr_Format(PyExc_ValueError, + "UnknownFields does not exist. " + "The parent message might be cleared."); + return NULL; + } + Py_ssize_t total_size = self->fields->field_count(); + if (index < 0) { + index = total_size + index; + } + if (index < 0 || index >= total_size) { + PyErr_Format(PyExc_IndexError, + "index (%zd) out of range", + index); + return NULL; + } + + return unknown_fields::NewPyUnknownFieldRef(self, index); +} + +PyObject* NewPyUnknownFields(CMessage* c_message) { + PyUnknownFields* self = reinterpret_cast<PyUnknownFields*>( + PyType_GenericAlloc(&PyUnknownFields_Type, 0)); + if (self == NULL) { + return NULL; + } + // Call "placement new" to initialize PyUnknownFields. + new (self) PyUnknownFields; + + Py_INCREF(c_message); + self->parent = reinterpret_cast<PyObject*>(c_message); + Message* message = c_message->message; + const Reflection* reflection = message->GetReflection(); + self->fields = &reflection->GetUnknownFields(*message); + + return reinterpret_cast<PyObject*>(self); +} + +PyObject* NewPyUnknownFieldRef(PyUnknownFields* parent, + Py_ssize_t index) { + PyUnknownFieldRef* self = reinterpret_cast<PyUnknownFieldRef*>( + PyType_GenericAlloc(&PyUnknownFieldRef_Type, 0)); + if (self == NULL) { + return NULL; + } + + Py_INCREF(parent); + self->parent = parent; + self->index = index; + + return reinterpret_cast<PyObject*>(self); +} + +static void Dealloc(PyObject* pself) { + PyUnknownFields* self = + reinterpret_cast<PyUnknownFields*>(pself); + if (PyObject_TypeCheck(self->parent, &PyUnknownFields_Type)) { + reinterpret_cast<PyUnknownFields*>( + self->parent)->sub_unknown_fields.erase(self); + } else { + reinterpret_cast<CMessage*>(self->parent)->unknown_field_set = nullptr; + } + Py_CLEAR(self->parent); + self->~PyUnknownFields(); + Py_TYPE(pself)->tp_free(pself); +} + +static PySequenceMethods SqMethods = { + Len, /* sq_length */ + 0, /* sq_concat */ + 0, /* sq_repeat */ + Item, /* sq_item */ + 0, /* sq_slice */ + 0, /* sq_ass_item */ +}; + +} // namespace unknown_fields + +PyTypeObject PyUnknownFields_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".PyUnknownFields", // tp_name + sizeof(PyUnknownFields), // tp_basicsize + 0, // tp_itemsize + unknown_fields::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + &unknown_fields::SqMethods, // tp_as_sequence + 0, // tp_as_mapping + PyObject_HashNotImplemented, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "unknown field set", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + 0, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init +}; + +namespace unknown_field { +static PyObject* PyUnknownFields_FromUnknownFieldSet( + PyUnknownFields* parent, const UnknownFieldSet& fields) { + PyUnknownFields* self = reinterpret_cast<PyUnknownFields*>( + PyType_GenericAlloc(&PyUnknownFields_Type, 0)); + if (self == NULL) { + return NULL; + } + // Call "placement new" to initialize PyUnknownFields. + new (self) PyUnknownFields; + + Py_INCREF(parent); + self->parent = reinterpret_cast<PyObject*>(parent); + self->fields = &fields; + parent->sub_unknown_fields.emplace(self); + + return reinterpret_cast<PyObject*>(self); +} + +const UnknownField* GetUnknownField(PyUnknownFieldRef* self) { + const UnknownFieldSet* fields = self->parent->fields; + if (fields == NULL) { + PyErr_Format(PyExc_ValueError, + "UnknownField does not exist. " + "The parent message might be cleared."); + return NULL; + } + ssize_t total_size = fields->field_count(); + if (self->index >= total_size) { + PyErr_Format(PyExc_ValueError, + "UnknownField does not exist. " + "The parent message might be cleared."); + return NULL; + } + return &fields->field(self->index); +} + +static PyObject* GetFieldNumber(PyUnknownFieldRef* self, void *closure) { + const UnknownField* unknown_field = GetUnknownField(self); + if (unknown_field == NULL) { + return NULL; + } + return PyInt_FromLong(unknown_field->number()); +} + +using internal::WireFormatLite; +static PyObject* GetWireType(PyUnknownFieldRef* self, void *closure) { + const UnknownField* unknown_field = GetUnknownField(self); + if (unknown_field == NULL) { + return NULL; + } + + // Assign a default value to suppress may-uninitialized warnings (errors + // when built in some places). + WireFormatLite::WireType wire_type = WireFormatLite::WIRETYPE_VARINT; + switch (unknown_field->type()) { + case UnknownField::TYPE_VARINT: + wire_type = WireFormatLite::WIRETYPE_VARINT; + break; + case UnknownField::TYPE_FIXED32: + wire_type = WireFormatLite::WIRETYPE_FIXED32; + break; + case UnknownField::TYPE_FIXED64: + wire_type = WireFormatLite::WIRETYPE_FIXED64; + break; + case UnknownField::TYPE_LENGTH_DELIMITED: + wire_type = WireFormatLite::WIRETYPE_LENGTH_DELIMITED; + break; + case UnknownField::TYPE_GROUP: + wire_type = WireFormatLite::WIRETYPE_START_GROUP; + break; + } + return PyInt_FromLong(wire_type); +} + +static PyObject* GetData(PyUnknownFieldRef* self, void *closure) { + const UnknownField* field = GetUnknownField(self); + if (field == NULL) { + return NULL; + } + PyObject* data = NULL; + switch (field->type()) { + case UnknownField::TYPE_VARINT: + data = PyInt_FromLong(field->varint()); + break; + case UnknownField::TYPE_FIXED32: + data = PyInt_FromLong(field->fixed32()); + break; + case UnknownField::TYPE_FIXED64: + data = PyInt_FromLong(field->fixed64()); + break; + case UnknownField::TYPE_LENGTH_DELIMITED: + data = PyBytes_FromStringAndSize(field->length_delimited().data(), + field->GetLengthDelimitedSize()); + break; + case UnknownField::TYPE_GROUP: + data = PyUnknownFields_FromUnknownFieldSet( + self->parent, field->group()); + break; + } + return data; +} + +static void Dealloc(PyObject* pself) { + PyUnknownFieldRef* self = + reinterpret_cast<PyUnknownFieldRef*>(pself); + Py_CLEAR(self->parent); +} + +static PyGetSetDef Getters[] = { + {"field_number", (getter)GetFieldNumber, NULL}, + {"wire_type", (getter)GetWireType, NULL}, + {"data", (getter)GetData, NULL}, + {NULL} +}; + +} // namespace unknown_field + +PyTypeObject PyUnknownFieldRef_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".PyUnknownFieldRef", // tp_name + sizeof(PyUnknownFieldRef), // tp_basicsize + 0, // tp_itemsize + unknown_field::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + PyObject_HashNotImplemented, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "unknown field", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + 0, // tp_methods + 0, // tp_members + unknown_field::Getters, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init +}; + + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/contrib/python/protobuf/py3/google/protobuf/pyext/unknown_fields.h b/contrib/python/protobuf/py3/google/protobuf/pyext/unknown_fields.h new file mode 100644 index 0000000000..94d55e148d --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/pyext/unknown_fields.h @@ -0,0 +1,90 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_UNKNOWN_FIELDS_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_UNKNOWN_FIELDS_H__ + +#include <Python.h> + +#include <memory> +#include <set> + +#include <google/protobuf/pyext/message.h> + +namespace google { +namespace protobuf { + +class UnknownField; +class UnknownFieldSet; + +namespace python { +struct CMessage; + +typedef struct PyUnknownFields { + PyObject_HEAD; + // Strong pointer to the parent CMessage or PyUnknownFields. + // The top PyUnknownFields holds a reference to its parent CMessage + // object before release. + // Sub PyUnknownFields holds reference to parent PyUnknownFields. + PyObject* parent; + + // Pointer to the C++ UnknownFieldSet. + // PyUnknownFields does not own this pointer. + const UnknownFieldSet* fields; + + // Weak references to child unknown fields. + std::set<PyUnknownFields*> sub_unknown_fields; +} PyUnknownFields; + +typedef struct PyUnknownFieldRef { + PyObject_HEAD; + // Every Python PyUnknownFieldRef holds a reference to its parent + // PyUnknownFields in order to keep it alive. + PyUnknownFields* parent; + + // The UnknownField index in UnknownFields. + Py_ssize_t index; +} UknownFieldRef; + +extern PyTypeObject PyUnknownFields_Type; +extern PyTypeObject PyUnknownFieldRef_Type; + +namespace unknown_fields { + +// Builds an PyUnknownFields for a specific message. +PyObject* NewPyUnknownFields(CMessage *parent); +void Clear(PyUnknownFields* self); + +} // namespace unknown_fields +} // namespace python +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_UNKNOWN_FIELDS_H__ diff --git a/contrib/python/protobuf/py3/google/protobuf/reflection.py b/contrib/python/protobuf/py3/google/protobuf/reflection.py new file mode 100644 index 0000000000..81e18859a8 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/reflection.py @@ -0,0 +1,95 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# This code is meant to work on Python 2.4 and above only. + +"""Contains a metaclass and helper functions used to create +protocol message classes from Descriptor objects at runtime. + +Recall that a metaclass is the "type" of a class. +(A class is to a metaclass what an instance is to a class.) + +In this case, we use the GeneratedProtocolMessageType metaclass +to inject all the useful functionality into the classes +output by the protocol compiler at compile-time. + +The upshot of all this is that the real implementation +details for ALL pure-Python protocol buffers are *here in +this file*. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + + +from google.protobuf import message_factory +from google.protobuf import symbol_database + +# The type of all Message classes. +# Part of the public interface, but normally only used by message factories. +GeneratedProtocolMessageType = message_factory._GENERATED_PROTOCOL_MESSAGE_TYPE + +MESSAGE_CLASS_CACHE = {} + + +# Deprecated. Please NEVER use reflection.ParseMessage(). +def ParseMessage(descriptor, byte_str): + """Generate a new Message instance from this Descriptor and a byte string. + + DEPRECATED: ParseMessage is deprecated because it is using MakeClass(). + Please use MessageFactory.GetPrototype() instead. + + Args: + descriptor: Protobuf Descriptor object + byte_str: Serialized protocol buffer byte string + + Returns: + Newly created protobuf Message object. + """ + result_class = MakeClass(descriptor) + new_msg = result_class() + new_msg.ParseFromString(byte_str) + return new_msg + + +# Deprecated. Please NEVER use reflection.MakeClass(). +def MakeClass(descriptor): + """Construct a class object for a protobuf described by descriptor. + + DEPRECATED: use MessageFactory.GetPrototype() instead. + + Args: + descriptor: A descriptor.Descriptor object describing the protobuf. + Returns: + The Message class object described by the descriptor. + """ + # Original implementation leads to duplicate message classes, which won't play + # well with extensions. Message factory info is also missing. + # Redirect to message_factory. + return symbol_database.Default().GetPrototype(descriptor) diff --git a/contrib/python/protobuf/py3/google/protobuf/service.py b/contrib/python/protobuf/py3/google/protobuf/service.py new file mode 100644 index 0000000000..5625246324 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/service.py @@ -0,0 +1,228 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""DEPRECATED: Declares the RPC service interfaces. + +This module declares the abstract interfaces underlying proto2 RPC +services. These are intended to be independent of any particular RPC +implementation, so that proto2 services can be used on top of a variety +of implementations. Starting with version 2.3.0, RPC implementations should +not try to build on these, but should instead provide code generator plugins +which generate code specific to the particular RPC implementation. This way +the generated code can be more appropriate for the implementation in use +and can avoid unnecessary layers of indirection. +""" + +__author__ = 'petar@google.com (Petar Petrov)' + + +class RpcException(Exception): + """Exception raised on failed blocking RPC method call.""" + pass + + +class Service(object): + + """Abstract base interface for protocol-buffer-based RPC services. + + Services themselves are abstract classes (implemented either by servers or as + stubs), but they subclass this base interface. The methods of this + interface can be used to call the methods of the service without knowing + its exact type at compile time (analogous to the Message interface). + """ + + def GetDescriptor(): + """Retrieves this service's descriptor.""" + raise NotImplementedError + + def CallMethod(self, method_descriptor, rpc_controller, + request, done): + """Calls a method of the service specified by method_descriptor. + + If "done" is None then the call is blocking and the response + message will be returned directly. Otherwise the call is asynchronous + and "done" will later be called with the response value. + + In the blocking case, RpcException will be raised on error. + + Preconditions: + + * method_descriptor.service == GetDescriptor + * request is of the exact same classes as returned by + GetRequestClass(method). + * After the call has started, the request must not be modified. + * "rpc_controller" is of the correct type for the RPC implementation being + used by this Service. For stubs, the "correct type" depends on the + RpcChannel which the stub is using. + + Postconditions: + + * "done" will be called when the method is complete. This may be + before CallMethod() returns or it may be at some point in the future. + * If the RPC failed, the response value passed to "done" will be None. + Further details about the failure can be found by querying the + RpcController. + """ + raise NotImplementedError + + def GetRequestClass(self, method_descriptor): + """Returns the class of the request message for the specified method. + + CallMethod() requires that the request is of a particular subclass of + Message. GetRequestClass() gets the default instance of this required + type. + + Example: + method = service.GetDescriptor().FindMethodByName("Foo") + request = stub.GetRequestClass(method)() + request.ParseFromString(input) + service.CallMethod(method, request, callback) + """ + raise NotImplementedError + + def GetResponseClass(self, method_descriptor): + """Returns the class of the response message for the specified method. + + This method isn't really needed, as the RpcChannel's CallMethod constructs + the response protocol message. It's provided anyway in case it is useful + for the caller to know the response type in advance. + """ + raise NotImplementedError + + +class RpcController(object): + + """An RpcController mediates a single method call. + + The primary purpose of the controller is to provide a way to manipulate + settings specific to the RPC implementation and to find out about RPC-level + errors. The methods provided by the RpcController interface are intended + to be a "least common denominator" set of features which we expect all + implementations to support. Specific implementations may provide more + advanced features (e.g. deadline propagation). + """ + + # Client-side methods below + + def Reset(self): + """Resets the RpcController to its initial state. + + After the RpcController has been reset, it may be reused in + a new call. Must not be called while an RPC is in progress. + """ + raise NotImplementedError + + def Failed(self): + """Returns true if the call failed. + + After a call has finished, returns true if the call failed. The possible + reasons for failure depend on the RPC implementation. Failed() must not + be called before a call has finished. If Failed() returns true, the + contents of the response message are undefined. + """ + raise NotImplementedError + + def ErrorText(self): + """If Failed is true, returns a human-readable description of the error.""" + raise NotImplementedError + + def StartCancel(self): + """Initiate cancellation. + + Advises the RPC system that the caller desires that the RPC call be + canceled. The RPC system may cancel it immediately, may wait awhile and + then cancel it, or may not even cancel the call at all. If the call is + canceled, the "done" callback will still be called and the RpcController + will indicate that the call failed at that time. + """ + raise NotImplementedError + + # Server-side methods below + + def SetFailed(self, reason): + """Sets a failure reason. + + Causes Failed() to return true on the client side. "reason" will be + incorporated into the message returned by ErrorText(). If you find + you need to return machine-readable information about failures, you + should incorporate it into your response protocol buffer and should + NOT call SetFailed(). + """ + raise NotImplementedError + + def IsCanceled(self): + """Checks if the client cancelled the RPC. + + If true, indicates that the client canceled the RPC, so the server may + as well give up on replying to it. The server should still call the + final "done" callback. + """ + raise NotImplementedError + + def NotifyOnCancel(self, callback): + """Sets a callback to invoke on cancel. + + Asks that the given callback be called when the RPC is canceled. The + callback will always be called exactly once. If the RPC completes without + being canceled, the callback will be called after completion. If the RPC + has already been canceled when NotifyOnCancel() is called, the callback + will be called immediately. + + NotifyOnCancel() must be called no more than once per request. + """ + raise NotImplementedError + + +class RpcChannel(object): + + """Abstract interface for an RPC channel. + + An RpcChannel represents a communication line to a service which can be used + to call that service's methods. The service may be running on another + machine. Normally, you should not use an RpcChannel directly, but instead + construct a stub {@link Service} wrapping it. Example: + + Example: + RpcChannel channel = rpcImpl.Channel("remotehost.example.com:1234") + RpcController controller = rpcImpl.Controller() + MyService service = MyService_Stub(channel) + service.MyMethod(controller, request, callback) + """ + + def CallMethod(self, method_descriptor, rpc_controller, + request, response_class, done): + """Calls the method identified by the descriptor. + + Call the given method of the remote service. The signature of this + procedure looks the same as Service.CallMethod(), but the requirements + are less strict in one important way: the request object doesn't have to + be of any specific class as long as its descriptor is method.input_type. + """ + raise NotImplementedError diff --git a/contrib/python/protobuf/py3/google/protobuf/service_reflection.py b/contrib/python/protobuf/py3/google/protobuf/service_reflection.py new file mode 100644 index 0000000000..75c51ff322 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/service_reflection.py @@ -0,0 +1,287 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains metaclasses used to create protocol service and service stub +classes from ServiceDescriptor objects at runtime. + +The GeneratedServiceType and GeneratedServiceStubType metaclasses are used to +inject all useful functionality into the classes output by the protocol +compiler at compile-time. +""" + +__author__ = 'petar@google.com (Petar Petrov)' + + +class GeneratedServiceType(type): + + """Metaclass for service classes created at runtime from ServiceDescriptors. + + Implementations for all methods described in the Service class are added here + by this class. We also create properties to allow getting/setting all fields + in the protocol message. + + The protocol compiler currently uses this metaclass to create protocol service + classes at runtime. Clients can also manually create their own classes at + runtime, as in this example:: + + mydescriptor = ServiceDescriptor(.....) + class MyProtoService(service.Service): + __metaclass__ = GeneratedServiceType + DESCRIPTOR = mydescriptor + myservice_instance = MyProtoService() + # ... + """ + + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __init__(cls, name, bases, dictionary): + """Creates a message service class. + + Args: + name: Name of the class (ignored, but required by the metaclass + protocol). + bases: Base classes of the class being constructed. + dictionary: The class dictionary of the class being constructed. + dictionary[_DESCRIPTOR_KEY] must contain a ServiceDescriptor object + describing this protocol service type. + """ + # Don't do anything if this class doesn't have a descriptor. This happens + # when a service class is subclassed. + if GeneratedServiceType._DESCRIPTOR_KEY not in dictionary: + return + + descriptor = dictionary[GeneratedServiceType._DESCRIPTOR_KEY] + service_builder = _ServiceBuilder(descriptor) + service_builder.BuildService(cls) + cls.DESCRIPTOR = descriptor + + +class GeneratedServiceStubType(GeneratedServiceType): + + """Metaclass for service stubs created at runtime from ServiceDescriptors. + + This class has similar responsibilities as GeneratedServiceType, except that + it creates the service stub classes. + """ + + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __init__(cls, name, bases, dictionary): + """Creates a message service stub class. + + Args: + name: Name of the class (ignored, here). + bases: Base classes of the class being constructed. + dictionary: The class dictionary of the class being constructed. + dictionary[_DESCRIPTOR_KEY] must contain a ServiceDescriptor object + describing this protocol service type. + """ + super(GeneratedServiceStubType, cls).__init__(name, bases, dictionary) + # Don't do anything if this class doesn't have a descriptor. This happens + # when a service stub is subclassed. + if GeneratedServiceStubType._DESCRIPTOR_KEY not in dictionary: + return + + descriptor = dictionary[GeneratedServiceStubType._DESCRIPTOR_KEY] + service_stub_builder = _ServiceStubBuilder(descriptor) + service_stub_builder.BuildServiceStub(cls) + + +class _ServiceBuilder(object): + + """This class constructs a protocol service class using a service descriptor. + + Given a service descriptor, this class constructs a class that represents + the specified service descriptor. One service builder instance constructs + exactly one service class. That means all instances of that class share the + same builder. + """ + + def __init__(self, service_descriptor): + """Initializes an instance of the service class builder. + + Args: + service_descriptor: ServiceDescriptor to use when constructing the + service class. + """ + self.descriptor = service_descriptor + + def BuildService(self, cls): + """Constructs the service class. + + Args: + cls: The class that will be constructed. + """ + + # CallMethod needs to operate with an instance of the Service class. This + # internal wrapper function exists only to be able to pass the service + # instance to the method that does the real CallMethod work. + def _WrapCallMethod(srvc, method_descriptor, + rpc_controller, request, callback): + return self._CallMethod(srvc, method_descriptor, + rpc_controller, request, callback) + self.cls = cls + cls.CallMethod = _WrapCallMethod + cls.GetDescriptor = staticmethod(lambda: self.descriptor) + cls.GetDescriptor.__doc__ = "Returns the service descriptor." + cls.GetRequestClass = self._GetRequestClass + cls.GetResponseClass = self._GetResponseClass + for method in self.descriptor.methods: + setattr(cls, method.name, self._GenerateNonImplementedMethod(method)) + + def _CallMethod(self, srvc, method_descriptor, + rpc_controller, request, callback): + """Calls the method described by a given method descriptor. + + Args: + srvc: Instance of the service for which this method is called. + method_descriptor: Descriptor that represent the method to call. + rpc_controller: RPC controller to use for this method's execution. + request: Request protocol message. + callback: A callback to invoke after the method has completed. + """ + if method_descriptor.containing_service != self.descriptor: + raise RuntimeError( + 'CallMethod() given method descriptor for wrong service type.') + method = getattr(srvc, method_descriptor.name) + return method(rpc_controller, request, callback) + + def _GetRequestClass(self, method_descriptor): + """Returns the class of the request protocol message. + + Args: + method_descriptor: Descriptor of the method for which to return the + request protocol message class. + + Returns: + A class that represents the input protocol message of the specified + method. + """ + if method_descriptor.containing_service != self.descriptor: + raise RuntimeError( + 'GetRequestClass() given method descriptor for wrong service type.') + return method_descriptor.input_type._concrete_class + + def _GetResponseClass(self, method_descriptor): + """Returns the class of the response protocol message. + + Args: + method_descriptor: Descriptor of the method for which to return the + response protocol message class. + + Returns: + A class that represents the output protocol message of the specified + method. + """ + if method_descriptor.containing_service != self.descriptor: + raise RuntimeError( + 'GetResponseClass() given method descriptor for wrong service type.') + return method_descriptor.output_type._concrete_class + + def _GenerateNonImplementedMethod(self, method): + """Generates and returns a method that can be set for a service methods. + + Args: + method: Descriptor of the service method for which a method is to be + generated. + + Returns: + A method that can be added to the service class. + """ + return lambda inst, rpc_controller, request, callback: ( + self._NonImplementedMethod(method.name, rpc_controller, callback)) + + def _NonImplementedMethod(self, method_name, rpc_controller, callback): + """The body of all methods in the generated service class. + + Args: + method_name: Name of the method being executed. + rpc_controller: RPC controller used to execute this method. + callback: A callback which will be invoked when the method finishes. + """ + rpc_controller.SetFailed('Method %s not implemented.' % method_name) + callback(None) + + +class _ServiceStubBuilder(object): + + """Constructs a protocol service stub class using a service descriptor. + + Given a service descriptor, this class constructs a suitable stub class. + A stub is just a type-safe wrapper around an RpcChannel which emulates a + local implementation of the service. + + One service stub builder instance constructs exactly one class. It means all + instances of that class share the same service stub builder. + """ + + def __init__(self, service_descriptor): + """Initializes an instance of the service stub class builder. + + Args: + service_descriptor: ServiceDescriptor to use when constructing the + stub class. + """ + self.descriptor = service_descriptor + + def BuildServiceStub(self, cls): + """Constructs the stub class. + + Args: + cls: The class that will be constructed. + """ + + def _ServiceStubInit(stub, rpc_channel): + stub.rpc_channel = rpc_channel + self.cls = cls + cls.__init__ = _ServiceStubInit + for method in self.descriptor.methods: + setattr(cls, method.name, self._GenerateStubMethod(method)) + + def _GenerateStubMethod(self, method): + return (lambda inst, rpc_controller, request, callback=None: + self._StubMethod(inst, method, rpc_controller, request, callback)) + + def _StubMethod(self, stub, method_descriptor, + rpc_controller, request, callback): + """The body of all service methods in the generated stub class. + + Args: + stub: Stub instance. + method_descriptor: Descriptor of the invoked method. + rpc_controller: Rpc controller to execute the method. + request: Request protocol message. + callback: A callback to execute when the method finishes. + Returns: + Response message (in case of blocking call). + """ + return stub.rpc_channel.CallMethod( + method_descriptor, rpc_controller, request, + method_descriptor.output_type._concrete_class, callback) diff --git a/contrib/python/protobuf/py3/google/protobuf/symbol_database.py b/contrib/python/protobuf/py3/google/protobuf/symbol_database.py new file mode 100644 index 0000000000..fdcf8cf06c --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/symbol_database.py @@ -0,0 +1,194 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""A database of Python protocol buffer generated symbols. + +SymbolDatabase is the MessageFactory for messages generated at compile time, +and makes it easy to create new instances of a registered type, given only the +type's protocol buffer symbol name. + +Example usage:: + + db = symbol_database.SymbolDatabase() + + # Register symbols of interest, from one or multiple files. + db.RegisterFileDescriptor(my_proto_pb2.DESCRIPTOR) + db.RegisterMessage(my_proto_pb2.MyMessage) + db.RegisterEnumDescriptor(my_proto_pb2.MyEnum.DESCRIPTOR) + + # The database can be used as a MessageFactory, to generate types based on + # their name: + types = db.GetMessages(['my_proto.proto']) + my_message_instance = types['MyMessage']() + + # The database's underlying descriptor pool can be queried, so it's not + # necessary to know a type's filename to be able to generate it: + filename = db.pool.FindFileContainingSymbol('MyMessage') + my_message_instance = db.GetMessages([filename])['MyMessage']() + + # This functionality is also provided directly via a convenience method: + my_message_instance = db.GetSymbol('MyMessage')() +""" + + +from google.protobuf.internal import api_implementation +from google.protobuf import descriptor_pool +from google.protobuf import message_factory + + +class SymbolDatabase(message_factory.MessageFactory): + """A database of Python generated symbols.""" + + def RegisterMessage(self, message): + """Registers the given message type in the local database. + + Calls to GetSymbol() and GetMessages() will return messages registered here. + + Args: + message: A :class:`google.protobuf.message.Message` subclass (or + instance); its descriptor will be registered. + + Returns: + The provided message. + """ + + desc = message.DESCRIPTOR + self._classes[desc] = message + self.RegisterMessageDescriptor(desc) + return message + + def RegisterMessageDescriptor(self, message_descriptor): + """Registers the given message descriptor in the local database. + + Args: + message_descriptor (Descriptor): the message descriptor to add. + """ + if api_implementation.Type() == 'python': + # pylint: disable=protected-access + self.pool._AddDescriptor(message_descriptor) + + def RegisterEnumDescriptor(self, enum_descriptor): + """Registers the given enum descriptor in the local database. + + Args: + enum_descriptor (EnumDescriptor): The enum descriptor to register. + + Returns: + EnumDescriptor: The provided descriptor. + """ + if api_implementation.Type() == 'python': + # pylint: disable=protected-access + self.pool._AddEnumDescriptor(enum_descriptor) + return enum_descriptor + + def RegisterServiceDescriptor(self, service_descriptor): + """Registers the given service descriptor in the local database. + + Args: + service_descriptor (ServiceDescriptor): the service descriptor to + register. + """ + if api_implementation.Type() == 'python': + # pylint: disable=protected-access + self.pool._AddServiceDescriptor(service_descriptor) + + def RegisterFileDescriptor(self, file_descriptor): + """Registers the given file descriptor in the local database. + + Args: + file_descriptor (FileDescriptor): The file descriptor to register. + """ + if api_implementation.Type() == 'python': + # pylint: disable=protected-access + self.pool._InternalAddFileDescriptor(file_descriptor) + + def GetSymbol(self, symbol): + """Tries to find a symbol in the local database. + + Currently, this method only returns message.Message instances, however, if + may be extended in future to support other symbol types. + + Args: + symbol (str): a protocol buffer symbol. + + Returns: + A Python class corresponding to the symbol. + + Raises: + KeyError: if the symbol could not be found. + """ + + return self._classes[self.pool.FindMessageTypeByName(symbol)] + + def GetMessages(self, files): + # TODO(amauryfa): Fix the differences with MessageFactory. + """Gets all registered messages from a specified file. + + Only messages already created and registered will be returned; (this is the + case for imported _pb2 modules) + But unlike MessageFactory, this version also returns already defined nested + messages, but does not register any message extensions. + + Args: + files (list[str]): The file names to extract messages from. + + Returns: + A dictionary mapping proto names to the message classes. + + Raises: + KeyError: if a file could not be found. + """ + + def _GetAllMessages(desc): + """Walk a message Descriptor and recursively yields all message names.""" + yield desc + for msg_desc in desc.nested_types: + for nested_desc in _GetAllMessages(msg_desc): + yield nested_desc + + result = {} + for file_name in files: + file_desc = self.pool.FindFileByName(file_name) + for msg_desc in file_desc.message_types_by_name.values(): + for desc in _GetAllMessages(msg_desc): + try: + result[desc.full_name] = self._classes[desc] + except KeyError: + # This descriptor has no registered class, skip it. + pass + return result + + +_DEFAULT = SymbolDatabase(pool=descriptor_pool.Default()) + + +def Default(): + """Returns the default SymbolDatabase.""" + return _DEFAULT diff --git a/contrib/python/protobuf/py3/google/protobuf/text_encoding.py b/contrib/python/protobuf/py3/google/protobuf/text_encoding.py new file mode 100644 index 0000000000..39898765f2 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/text_encoding.py @@ -0,0 +1,117 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Encoding related utilities.""" +import re + +import six + +_cescape_chr_to_symbol_map = {} +_cescape_chr_to_symbol_map[9] = r'\t' # optional escape +_cescape_chr_to_symbol_map[10] = r'\n' # optional escape +_cescape_chr_to_symbol_map[13] = r'\r' # optional escape +_cescape_chr_to_symbol_map[34] = r'\"' # necessary escape +_cescape_chr_to_symbol_map[39] = r"\'" # optional escape +_cescape_chr_to_symbol_map[92] = r'\\' # necessary escape + +# Lookup table for unicode +_cescape_unicode_to_str = [chr(i) for i in range(0, 256)] +for byte, string in _cescape_chr_to_symbol_map.items(): + _cescape_unicode_to_str[byte] = string + +# Lookup table for non-utf8, with necessary escapes at (o >= 127 or o < 32) +_cescape_byte_to_str = ([r'\%03o' % i for i in range(0, 32)] + + [chr(i) for i in range(32, 127)] + + [r'\%03o' % i for i in range(127, 256)]) +for byte, string in _cescape_chr_to_symbol_map.items(): + _cescape_byte_to_str[byte] = string +del byte, string + + +def CEscape(text, as_utf8): + # type: (...) -> str + """Escape a bytes string for use in an text protocol buffer. + + Args: + text: A byte string to be escaped. + as_utf8: Specifies if result may contain non-ASCII characters. + In Python 3 this allows unescaped non-ASCII Unicode characters. + In Python 2 the return value will be valid UTF-8 rather than only ASCII. + Returns: + Escaped string (str). + """ + # Python's text.encode() 'string_escape' or 'unicode_escape' codecs do not + # satisfy our needs; they encodes unprintable characters using two-digit hex + # escapes whereas our C++ unescaping function allows hex escapes to be any + # length. So, "\0011".encode('string_escape') ends up being "\\x011", which + # will be decoded in C++ as a single-character string with char code 0x11. + if six.PY3: + text_is_unicode = isinstance(text, str) + if as_utf8 and text_is_unicode: + # We're already unicode, no processing beyond control char escapes. + return text.translate(_cescape_chr_to_symbol_map) + ord_ = ord if text_is_unicode else lambda x: x # bytes iterate as ints. + else: + ord_ = ord # PY2 + if as_utf8: + return ''.join(_cescape_unicode_to_str[ord_(c)] for c in text) + return ''.join(_cescape_byte_to_str[ord_(c)] for c in text) + + +_CUNESCAPE_HEX = re.compile(r'(\\+)x([0-9a-fA-F])(?![0-9a-fA-F])') + + +def CUnescape(text): + # type: (str) -> bytes + """Unescape a text string with C-style escape sequences to UTF-8 bytes. + + Args: + text: The data to parse in a str. + Returns: + A byte string. + """ + + def ReplaceHex(m): + # Only replace the match if the number of leading back slashes is odd. i.e. + # the slash itself is not escaped. + if len(m.group(1)) & 1: + return m.group(1) + 'x0' + m.group(2) + return m.group(0) + + # This is required because the 'string_escape' encoding doesn't + # allow single-digit hex escapes (like '\xf'). + result = _CUNESCAPE_HEX.sub(ReplaceHex, text) + + if six.PY2: + return result.decode('string_escape') + return (result.encode('utf-8') # PY3: Make it bytes to allow decode. + .decode('unicode_escape') + # Make it bytes again to return the proper type. + .encode('raw_unicode_escape')) diff --git a/contrib/python/protobuf/py3/google/protobuf/text_format.py b/contrib/python/protobuf/py3/google/protobuf/text_format.py new file mode 100644 index 0000000000..9c4ca90ee6 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/text_format.py @@ -0,0 +1,1826 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains routines for printing protocol messages in text format. + +Simple usage example:: + + # Create a proto object and serialize it to a text proto string. + message = my_proto_pb2.MyMessage(foo='bar') + text_proto = text_format.MessageToString(message) + + # Parse a text proto string. + message = text_format.Parse(text_proto, my_proto_pb2.MyMessage()) +""" + +__author__ = 'kenton@google.com (Kenton Varda)' + +# TODO(b/129989314) Import thread contention leads to test failures. +import encodings.raw_unicode_escape # pylint: disable=unused-import +import encodings.unicode_escape # pylint: disable=unused-import +import io +import math +import re +import six + +from google.protobuf.internal import decoder +from google.protobuf.internal import type_checkers +from google.protobuf import descriptor +from google.protobuf import text_encoding + +if six.PY3: + long = int # pylint: disable=redefined-builtin,invalid-name + +# pylint: disable=g-import-not-at-top +__all__ = ['MessageToString', 'Parse', 'PrintMessage', 'PrintField', + 'PrintFieldValue', 'Merge', 'MessageToBytes'] + +_INTEGER_CHECKERS = (type_checkers.Uint32ValueChecker(), + type_checkers.Int32ValueChecker(), + type_checkers.Uint64ValueChecker(), + type_checkers.Int64ValueChecker()) +_FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?$', re.IGNORECASE) +_FLOAT_NAN = re.compile('nanf?$', re.IGNORECASE) +_QUOTES = frozenset(("'", '"')) +_ANY_FULL_TYPE_NAME = 'google.protobuf.Any' + + +class Error(Exception): + """Top-level module error for text_format.""" + + +class ParseError(Error): + """Thrown in case of text parsing or tokenizing error.""" + + def __init__(self, message=None, line=None, column=None): + if message is not None and line is not None: + loc = str(line) + if column is not None: + loc += ':{0}'.format(column) + message = '{0} : {1}'.format(loc, message) + if message is not None: + super(ParseError, self).__init__(message) + else: + super(ParseError, self).__init__() + self._line = line + self._column = column + + def GetLine(self): + return self._line + + def GetColumn(self): + return self._column + + +class TextWriter(object): + + def __init__(self, as_utf8): + if six.PY2: + self._writer = io.BytesIO() + else: + self._writer = io.StringIO() + + def write(self, val): + if six.PY2: + if isinstance(val, six.text_type): + val = val.encode('utf-8') + return self._writer.write(val) + + def close(self): + return self._writer.close() + + def getvalue(self): + return self._writer.getvalue() + + +def MessageToString( + message, + as_utf8=False, + as_one_line=False, + use_short_repeated_primitives=False, + pointy_brackets=False, + use_index_order=False, + float_format=None, + double_format=None, + use_field_number=False, + descriptor_pool=None, + indent=0, + message_formatter=None, + print_unknown_fields=False, + force_colon=False): + # type: (...) -> str + """Convert protobuf message to text format. + + Double values can be formatted compactly with 15 digits of + precision (which is the most that IEEE 754 "double" can guarantee) + using double_format='.15g'. To ensure that converting to text and back to a + proto will result in an identical value, double_format='.17g' should be used. + + Args: + message: The protocol buffers message. + as_utf8: Return unescaped Unicode for non-ASCII characters. + In Python 3 actual Unicode characters may appear as is in strings. + In Python 2 the return value will be valid UTF-8 rather than only ASCII. + as_one_line: Don't introduce newlines between fields. + use_short_repeated_primitives: Use short repeated format for primitives. + pointy_brackets: If True, use angle brackets instead of curly braces for + nesting. + use_index_order: If True, fields of a proto message will be printed using + the order defined in source code instead of the field number, extensions + will be printed at the end of the message and their relative order is + determined by the extension number. By default, use the field number + order. + float_format (str): If set, use this to specify float field formatting + (per the "Format Specification Mini-Language"); otherwise, shortest float + that has same value in wire will be printed. Also affect double field + if double_format is not set but float_format is set. + double_format (str): If set, use this to specify double field formatting + (per the "Format Specification Mini-Language"); if it is not set but + float_format is set, use float_format. Otherwise, use ``str()`` + use_field_number: If True, print field numbers instead of names. + descriptor_pool (DescriptorPool): Descriptor pool used to resolve Any types. + indent (int): The initial indent level, in terms of spaces, for pretty + print. + message_formatter (function(message, indent, as_one_line) -> unicode|None): + Custom formatter for selected sub-messages (usually based on message + type). Use to pretty print parts of the protobuf for easier diffing. + print_unknown_fields: If True, unknown fields will be printed. + force_colon: If set, a colon will be added after the field name even if the + field is a proto message. + + Returns: + str: A string of the text formatted protocol buffer message. + """ + out = TextWriter(as_utf8) + printer = _Printer( + out, + indent, + as_utf8, + as_one_line, + use_short_repeated_primitives, + pointy_brackets, + use_index_order, + float_format, + double_format, + use_field_number, + descriptor_pool, + message_formatter, + print_unknown_fields=print_unknown_fields, + force_colon=force_colon) + printer.PrintMessage(message) + result = out.getvalue() + out.close() + if as_one_line: + return result.rstrip() + return result + + +def MessageToBytes(message, **kwargs): + # type: (...) -> bytes + """Convert protobuf message to encoded text format. See MessageToString.""" + text = MessageToString(message, **kwargs) + if isinstance(text, bytes): + return text + codec = 'utf-8' if kwargs.get('as_utf8') else 'ascii' + return text.encode(codec) + + +def _IsMapEntry(field): + return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.message_type.has_options and + field.message_type.GetOptions().map_entry) + + +def PrintMessage(message, + out, + indent=0, + as_utf8=False, + as_one_line=False, + use_short_repeated_primitives=False, + pointy_brackets=False, + use_index_order=False, + float_format=None, + double_format=None, + use_field_number=False, + descriptor_pool=None, + message_formatter=None, + print_unknown_fields=False, + force_colon=False): + printer = _Printer( + out=out, indent=indent, as_utf8=as_utf8, + as_one_line=as_one_line, + use_short_repeated_primitives=use_short_repeated_primitives, + pointy_brackets=pointy_brackets, + use_index_order=use_index_order, + float_format=float_format, + double_format=double_format, + use_field_number=use_field_number, + descriptor_pool=descriptor_pool, + message_formatter=message_formatter, + print_unknown_fields=print_unknown_fields, + force_colon=force_colon) + printer.PrintMessage(message) + + +def PrintField(field, + value, + out, + indent=0, + as_utf8=False, + as_one_line=False, + use_short_repeated_primitives=False, + pointy_brackets=False, + use_index_order=False, + float_format=None, + double_format=None, + message_formatter=None, + print_unknown_fields=False, + force_colon=False): + """Print a single field name/value pair.""" + printer = _Printer(out, indent, as_utf8, as_one_line, + use_short_repeated_primitives, pointy_brackets, + use_index_order, float_format, double_format, + message_formatter=message_formatter, + print_unknown_fields=print_unknown_fields, + force_colon=force_colon) + printer.PrintField(field, value) + + +def PrintFieldValue(field, + value, + out, + indent=0, + as_utf8=False, + as_one_line=False, + use_short_repeated_primitives=False, + pointy_brackets=False, + use_index_order=False, + float_format=None, + double_format=None, + message_formatter=None, + print_unknown_fields=False, + force_colon=False): + """Print a single field value (not including name).""" + printer = _Printer(out, indent, as_utf8, as_one_line, + use_short_repeated_primitives, pointy_brackets, + use_index_order, float_format, double_format, + message_formatter=message_formatter, + print_unknown_fields=print_unknown_fields, + force_colon=force_colon) + printer.PrintFieldValue(field, value) + + +def _BuildMessageFromTypeName(type_name, descriptor_pool): + """Returns a protobuf message instance. + + Args: + type_name: Fully-qualified protobuf message type name string. + descriptor_pool: DescriptorPool instance. + + Returns: + A Message instance of type matching type_name, or None if the a Descriptor + wasn't found matching type_name. + """ + # pylint: disable=g-import-not-at-top + if descriptor_pool is None: + from google.protobuf import descriptor_pool as pool_mod + descriptor_pool = pool_mod.Default() + from google.protobuf import symbol_database + database = symbol_database.Default() + try: + message_descriptor = descriptor_pool.FindMessageTypeByName(type_name) + except KeyError: + return None + message_type = database.GetPrototype(message_descriptor) + return message_type() + + +# These values must match WireType enum in google/protobuf/wire_format.h. +WIRETYPE_LENGTH_DELIMITED = 2 +WIRETYPE_START_GROUP = 3 + + +class _Printer(object): + """Text format printer for protocol message.""" + + def __init__( + self, + out, + indent=0, + as_utf8=False, + as_one_line=False, + use_short_repeated_primitives=False, + pointy_brackets=False, + use_index_order=False, + float_format=None, + double_format=None, + use_field_number=False, + descriptor_pool=None, + message_formatter=None, + print_unknown_fields=False, + force_colon=False): + """Initialize the Printer. + + Double values can be formatted compactly with 15 digits of precision + (which is the most that IEEE 754 "double" can guarantee) using + double_format='.15g'. To ensure that converting to text and back to a proto + will result in an identical value, double_format='.17g' should be used. + + Args: + out: To record the text format result. + indent: The initial indent level for pretty print. + as_utf8: Return unescaped Unicode for non-ASCII characters. + In Python 3 actual Unicode characters may appear as is in strings. + In Python 2 the return value will be valid UTF-8 rather than ASCII. + as_one_line: Don't introduce newlines between fields. + use_short_repeated_primitives: Use short repeated format for primitives. + pointy_brackets: If True, use angle brackets instead of curly braces for + nesting. + use_index_order: If True, print fields of a proto message using the order + defined in source code instead of the field number. By default, use the + field number order. + float_format: If set, use this to specify float field formatting + (per the "Format Specification Mini-Language"); otherwise, shortest + float that has same value in wire will be printed. Also affect double + field if double_format is not set but float_format is set. + double_format: If set, use this to specify double field formatting + (per the "Format Specification Mini-Language"); if it is not set but + float_format is set, use float_format. Otherwise, str() is used. + use_field_number: If True, print field numbers instead of names. + descriptor_pool: A DescriptorPool used to resolve Any types. + message_formatter: A function(message, indent, as_one_line): unicode|None + to custom format selected sub-messages (usually based on message type). + Use to pretty print parts of the protobuf for easier diffing. + print_unknown_fields: If True, unknown fields will be printed. + force_colon: If set, a colon will be added after the field name even if + the field is a proto message. + """ + self.out = out + self.indent = indent + self.as_utf8 = as_utf8 + self.as_one_line = as_one_line + self.use_short_repeated_primitives = use_short_repeated_primitives + self.pointy_brackets = pointy_brackets + self.use_index_order = use_index_order + self.float_format = float_format + if double_format is not None: + self.double_format = double_format + else: + self.double_format = float_format + self.use_field_number = use_field_number + self.descriptor_pool = descriptor_pool + self.message_formatter = message_formatter + self.print_unknown_fields = print_unknown_fields + self.force_colon = force_colon + + def _TryPrintAsAnyMessage(self, message): + """Serializes if message is a google.protobuf.Any field.""" + if '/' not in message.type_url: + return False + packed_message = _BuildMessageFromTypeName(message.TypeName(), + self.descriptor_pool) + if packed_message: + packed_message.MergeFromString(message.value) + colon = ':' if self.force_colon else '' + self.out.write('%s[%s]%s ' % (self.indent * ' ', message.type_url, colon)) + self._PrintMessageFieldValue(packed_message) + self.out.write(' ' if self.as_one_line else '\n') + return True + else: + return False + + def _TryCustomFormatMessage(self, message): + formatted = self.message_formatter(message, self.indent, self.as_one_line) + if formatted is None: + return False + + out = self.out + out.write(' ' * self.indent) + out.write(formatted) + out.write(' ' if self.as_one_line else '\n') + return True + + def PrintMessage(self, message): + """Convert protobuf message to text format. + + Args: + message: The protocol buffers message. + """ + if self.message_formatter and self._TryCustomFormatMessage(message): + return + if (message.DESCRIPTOR.full_name == _ANY_FULL_TYPE_NAME and + self._TryPrintAsAnyMessage(message)): + return + fields = message.ListFields() + if self.use_index_order: + fields.sort( + key=lambda x: x[0].number if x[0].is_extension else x[0].index) + for field, value in fields: + if _IsMapEntry(field): + for key in sorted(value): + # This is slow for maps with submessage entries because it copies the + # entire tree. Unfortunately this would take significant refactoring + # of this file to work around. + # + # TODO(haberman): refactor and optimize if this becomes an issue. + entry_submsg = value.GetEntryClass()(key=key, value=value[key]) + self.PrintField(field, entry_submsg) + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if (self.use_short_repeated_primitives + and field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE + and field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_STRING): + self._PrintShortRepeatedPrimitivesValue(field, value) + else: + for element in value: + self.PrintField(field, element) + else: + self.PrintField(field, value) + + if self.print_unknown_fields: + self._PrintUnknownFields(message.UnknownFields()) + + def _PrintUnknownFields(self, unknown_fields): + """Print unknown fields.""" + out = self.out + for field in unknown_fields: + out.write(' ' * self.indent) + out.write(str(field.field_number)) + if field.wire_type == WIRETYPE_START_GROUP: + if self.as_one_line: + out.write(' { ') + else: + out.write(' {\n') + self.indent += 2 + + self._PrintUnknownFields(field.data) + + if self.as_one_line: + out.write('} ') + else: + self.indent -= 2 + out.write(' ' * self.indent + '}\n') + elif field.wire_type == WIRETYPE_LENGTH_DELIMITED: + try: + # If this field is parseable as a Message, it is probably + # an embedded message. + # pylint: disable=protected-access + (embedded_unknown_message, pos) = decoder._DecodeUnknownFieldSet( + memoryview(field.data), 0, len(field.data)) + except Exception: # pylint: disable=broad-except + pos = 0 + + if pos == len(field.data): + if self.as_one_line: + out.write(' { ') + else: + out.write(' {\n') + self.indent += 2 + + self._PrintUnknownFields(embedded_unknown_message) + + if self.as_one_line: + out.write('} ') + else: + self.indent -= 2 + out.write(' ' * self.indent + '}\n') + else: + # A string or bytes field. self.as_utf8 may not work. + out.write(': \"') + out.write(text_encoding.CEscape(field.data, False)) + out.write('\" ' if self.as_one_line else '\"\n') + else: + # varint, fixed32, fixed64 + out.write(': ') + out.write(str(field.data)) + out.write(' ' if self.as_one_line else '\n') + + def _PrintFieldName(self, field): + """Print field name.""" + out = self.out + out.write(' ' * self.indent) + if self.use_field_number: + out.write(str(field.number)) + else: + if field.is_extension: + out.write('[') + if (field.containing_type.GetOptions().message_set_wire_format and + field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL): + out.write(field.message_type.full_name) + else: + out.write(field.full_name) + out.write(']') + elif field.type == descriptor.FieldDescriptor.TYPE_GROUP: + # For groups, use the capitalized name. + out.write(field.message_type.name) + else: + out.write(field.name) + + if (self.force_colon or + field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE): + # The colon is optional in this case, but our cross-language golden files + # don't include it. Here, the colon is only included if force_colon is + # set to True + out.write(':') + + def PrintField(self, field, value): + """Print a single field name/value pair.""" + self._PrintFieldName(field) + self.out.write(' ') + self.PrintFieldValue(field, value) + self.out.write(' ' if self.as_one_line else '\n') + + def _PrintShortRepeatedPrimitivesValue(self, field, value): + """"Prints short repeated primitives value.""" + # Note: this is called only when value has at least one element. + self._PrintFieldName(field) + self.out.write(' [') + for i in six.moves.range(len(value) - 1): + self.PrintFieldValue(field, value[i]) + self.out.write(', ') + self.PrintFieldValue(field, value[-1]) + self.out.write(']') + self.out.write(' ' if self.as_one_line else '\n') + + def _PrintMessageFieldValue(self, value): + if self.pointy_brackets: + openb = '<' + closeb = '>' + else: + openb = '{' + closeb = '}' + + if self.as_one_line: + self.out.write('%s ' % openb) + self.PrintMessage(value) + self.out.write(closeb) + else: + self.out.write('%s\n' % openb) + self.indent += 2 + self.PrintMessage(value) + self.indent -= 2 + self.out.write(' ' * self.indent + closeb) + + def PrintFieldValue(self, field, value): + """Print a single field value (not including name). + + For repeated fields, the value should be a single element. + + Args: + field: The descriptor of the field to be printed. + value: The value of the field. + """ + out = self.out + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + self._PrintMessageFieldValue(value) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: + enum_value = field.enum_type.values_by_number.get(value, None) + if enum_value is not None: + out.write(enum_value.name) + else: + out.write(str(value)) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: + out.write('\"') + if isinstance(value, six.text_type) and (six.PY2 or not self.as_utf8): + out_value = value.encode('utf-8') + else: + out_value = value + if field.type == descriptor.FieldDescriptor.TYPE_BYTES: + # We always need to escape all binary data in TYPE_BYTES fields. + out_as_utf8 = False + else: + out_as_utf8 = self.as_utf8 + out.write(text_encoding.CEscape(out_value, out_as_utf8)) + out.write('\"') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: + if value: + out.write('true') + else: + out.write('false') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_FLOAT: + if self.float_format is not None: + out.write('{1:{0}}'.format(self.float_format, value)) + else: + if math.isnan(value): + out.write(str(value)) + else: + out.write(str(type_checkers.ToShortestFloat(value))) + elif (field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_DOUBLE and + self.double_format is not None): + out.write('{1:{0}}'.format(self.double_format, value)) + else: + out.write(str(value)) + + +def Parse(text, + message, + allow_unknown_extension=False, + allow_field_number=False, + descriptor_pool=None, + allow_unknown_field=False): + """Parses a text representation of a protocol message into a message. + + NOTE: for historical reasons this function does not clear the input + message. This is different from what the binary msg.ParseFrom(...) does. + If text contains a field already set in message, the value is appended if the + field is repeated. Otherwise, an error is raised. + + Example:: + + a = MyProto() + a.repeated_field.append('test') + b = MyProto() + + # Repeated fields are combined + text_format.Parse(repr(a), b) + text_format.Parse(repr(a), b) # repeated_field contains ["test", "test"] + + # Non-repeated fields cannot be overwritten + a.singular_field = 1 + b.singular_field = 2 + text_format.Parse(repr(a), b) # ParseError + + # Binary version: + b.ParseFromString(a.SerializeToString()) # repeated_field is now "test" + + Caller is responsible for clearing the message as needed. + + Args: + text (str): Message text representation. + message (Message): A protocol buffer message to merge into. + allow_unknown_extension: if True, skip over missing extensions and keep + parsing + allow_field_number: if True, both field number and field name are allowed. + descriptor_pool (DescriptorPool): Descriptor pool used to resolve Any types. + allow_unknown_field: if True, skip over unknown field and keep + parsing. Avoid to use this option if possible. It may hide some + errors (e.g. spelling error on field name) + + Returns: + Message: The same message passed as argument. + + Raises: + ParseError: On text parsing problems. + """ + return ParseLines(text.split(b'\n' if isinstance(text, bytes) else u'\n'), + message, + allow_unknown_extension, + allow_field_number, + descriptor_pool=descriptor_pool, + allow_unknown_field=allow_unknown_field) + + +def Merge(text, + message, + allow_unknown_extension=False, + allow_field_number=False, + descriptor_pool=None, + allow_unknown_field=False): + """Parses a text representation of a protocol message into a message. + + Like Parse(), but allows repeated values for a non-repeated field, and uses + the last one. This means any non-repeated, top-level fields specified in text + replace those in the message. + + Args: + text (str): Message text representation. + message (Message): A protocol buffer message to merge into. + allow_unknown_extension: if True, skip over missing extensions and keep + parsing + allow_field_number: if True, both field number and field name are allowed. + descriptor_pool (DescriptorPool): Descriptor pool used to resolve Any types. + allow_unknown_field: if True, skip over unknown field and keep + parsing. Avoid to use this option if possible. It may hide some + errors (e.g. spelling error on field name) + + Returns: + Message: The same message passed as argument. + + Raises: + ParseError: On text parsing problems. + """ + return MergeLines( + text.split(b'\n' if isinstance(text, bytes) else u'\n'), + message, + allow_unknown_extension, + allow_field_number, + descriptor_pool=descriptor_pool, + allow_unknown_field=allow_unknown_field) + + +def ParseLines(lines, + message, + allow_unknown_extension=False, + allow_field_number=False, + descriptor_pool=None, + allow_unknown_field=False): + """Parses a text representation of a protocol message into a message. + + See Parse() for caveats. + + Args: + lines: An iterable of lines of a message's text representation. + message: A protocol buffer message to merge into. + allow_unknown_extension: if True, skip over missing extensions and keep + parsing + allow_field_number: if True, both field number and field name are allowed. + descriptor_pool: A DescriptorPool used to resolve Any types. + allow_unknown_field: if True, skip over unknown field and keep + parsing. Avoid to use this option if possible. It may hide some + errors (e.g. spelling error on field name) + + Returns: + The same message passed as argument. + + Raises: + ParseError: On text parsing problems. + """ + parser = _Parser(allow_unknown_extension, + allow_field_number, + descriptor_pool=descriptor_pool, + allow_unknown_field=allow_unknown_field) + return parser.ParseLines(lines, message) + + +def MergeLines(lines, + message, + allow_unknown_extension=False, + allow_field_number=False, + descriptor_pool=None, + allow_unknown_field=False): + """Parses a text representation of a protocol message into a message. + + See Merge() for more details. + + Args: + lines: An iterable of lines of a message's text representation. + message: A protocol buffer message to merge into. + allow_unknown_extension: if True, skip over missing extensions and keep + parsing + allow_field_number: if True, both field number and field name are allowed. + descriptor_pool: A DescriptorPool used to resolve Any types. + allow_unknown_field: if True, skip over unknown field and keep + parsing. Avoid to use this option if possible. It may hide some + errors (e.g. spelling error on field name) + + Returns: + The same message passed as argument. + + Raises: + ParseError: On text parsing problems. + """ + parser = _Parser(allow_unknown_extension, + allow_field_number, + descriptor_pool=descriptor_pool, + allow_unknown_field=allow_unknown_field) + return parser.MergeLines(lines, message) + + +class _Parser(object): + """Text format parser for protocol message.""" + + def __init__(self, + allow_unknown_extension=False, + allow_field_number=False, + descriptor_pool=None, + allow_unknown_field=False): + self.allow_unknown_extension = allow_unknown_extension + self.allow_field_number = allow_field_number + self.descriptor_pool = descriptor_pool + self.allow_unknown_field = allow_unknown_field + + def ParseLines(self, lines, message): + """Parses a text representation of a protocol message into a message.""" + self._allow_multiple_scalars = False + self._ParseOrMerge(lines, message) + return message + + def MergeLines(self, lines, message): + """Merges a text representation of a protocol message into a message.""" + self._allow_multiple_scalars = True + self._ParseOrMerge(lines, message) + return message + + def _ParseOrMerge(self, lines, message): + """Converts a text representation of a protocol message into a message. + + Args: + lines: Lines of a message's text representation. + message: A protocol buffer message to merge into. + + Raises: + ParseError: On text parsing problems. + """ + # Tokenize expects native str lines. + if six.PY2: + str_lines = (line if isinstance(line, str) else line.encode('utf-8') + for line in lines) + else: + str_lines = (line if isinstance(line, str) else line.decode('utf-8') + for line in lines) + tokenizer = Tokenizer(str_lines) + while not tokenizer.AtEnd(): + self._MergeField(tokenizer, message) + + def _MergeField(self, tokenizer, message): + """Merges a single protocol message field into a message. + + Args: + tokenizer: A tokenizer to parse the field name and values. + message: A protocol message to record the data. + + Raises: + ParseError: In case of text parsing problems. + """ + message_descriptor = message.DESCRIPTOR + if (message_descriptor.full_name == _ANY_FULL_TYPE_NAME and + tokenizer.TryConsume('[')): + type_url_prefix, packed_type_name = self._ConsumeAnyTypeUrl(tokenizer) + tokenizer.Consume(']') + tokenizer.TryConsume(':') + if tokenizer.TryConsume('<'): + expanded_any_end_token = '>' + else: + tokenizer.Consume('{') + expanded_any_end_token = '}' + expanded_any_sub_message = _BuildMessageFromTypeName(packed_type_name, + self.descriptor_pool) + if not expanded_any_sub_message: + raise ParseError('Type %s not found in descriptor pool' % + packed_type_name) + while not tokenizer.TryConsume(expanded_any_end_token): + if tokenizer.AtEnd(): + raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % + (expanded_any_end_token,)) + self._MergeField(tokenizer, expanded_any_sub_message) + deterministic = False + + message.Pack(expanded_any_sub_message, + type_url_prefix=type_url_prefix, + deterministic=deterministic) + return + + if tokenizer.TryConsume('['): + name = [tokenizer.ConsumeIdentifier()] + while tokenizer.TryConsume('.'): + name.append(tokenizer.ConsumeIdentifier()) + name = '.'.join(name) + + if not message_descriptor.is_extendable: + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" does not have extensions.' % + message_descriptor.full_name) + # pylint: disable=protected-access + field = message.Extensions._FindExtensionByName(name) + # pylint: enable=protected-access + + + if not field: + if self.allow_unknown_extension: + field = None + else: + raise tokenizer.ParseErrorPreviousToken( + 'Extension "%s" not registered. ' + 'Did you import the _pb2 module which defines it? ' + 'If you are trying to place the extension in the MessageSet ' + 'field of another message that is in an Any or MessageSet field, ' + 'that message\'s _pb2 module must be imported as well' % name) + elif message_descriptor != field.containing_type: + raise tokenizer.ParseErrorPreviousToken( + 'Extension "%s" does not extend message type "%s".' % + (name, message_descriptor.full_name)) + + tokenizer.Consume(']') + + else: + name = tokenizer.ConsumeIdentifierOrNumber() + if self.allow_field_number and name.isdigit(): + number = ParseInteger(name, True, True) + field = message_descriptor.fields_by_number.get(number, None) + if not field and message_descriptor.is_extendable: + field = message.Extensions._FindExtensionByNumber(number) + else: + field = message_descriptor.fields_by_name.get(name, None) + + # Group names are expected to be capitalized as they appear in the + # .proto file, which actually matches their type names, not their field + # names. + if not field: + field = message_descriptor.fields_by_name.get(name.lower(), None) + if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP: + field = None + + if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and + field.message_type.name != name): + field = None + + if not field and not self.allow_unknown_field: + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" has no field named "%s".' % + (message_descriptor.full_name, name)) + + if field: + if not self._allow_multiple_scalars and field.containing_oneof: + # Check if there's a different field set in this oneof. + # Note that we ignore the case if the same field was set before, and we + # apply _allow_multiple_scalars to non-scalar fields as well. + which_oneof = message.WhichOneof(field.containing_oneof.name) + if which_oneof is not None and which_oneof != field.name: + raise tokenizer.ParseErrorPreviousToken( + 'Field "%s" is specified along with field "%s", another member ' + 'of oneof "%s" for message type "%s".' % + (field.name, which_oneof, field.containing_oneof.name, + message_descriptor.full_name)) + + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + tokenizer.TryConsume(':') + merger = self._MergeMessageField + else: + tokenizer.Consume(':') + merger = self._MergeScalarField + + if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED and + tokenizer.TryConsume('[')): + # Short repeated format, e.g. "foo: [1, 2, 3]" + if not tokenizer.TryConsume(']'): + while True: + merger(tokenizer, message, field) + if tokenizer.TryConsume(']'): + break + tokenizer.Consume(',') + + else: + merger(tokenizer, message, field) + + else: # Proto field is unknown. + assert (self.allow_unknown_extension or self.allow_unknown_field) + _SkipFieldContents(tokenizer) + + # For historical reasons, fields may optionally be separated by commas or + # semicolons. + if not tokenizer.TryConsume(','): + tokenizer.TryConsume(';') + + + def _ConsumeAnyTypeUrl(self, tokenizer): + """Consumes a google.protobuf.Any type URL and returns the type name.""" + # Consume "type.googleapis.com/". + prefix = [tokenizer.ConsumeIdentifier()] + tokenizer.Consume('.') + prefix.append(tokenizer.ConsumeIdentifier()) + tokenizer.Consume('.') + prefix.append(tokenizer.ConsumeIdentifier()) + tokenizer.Consume('/') + # Consume the fully-qualified type name. + name = [tokenizer.ConsumeIdentifier()] + while tokenizer.TryConsume('.'): + name.append(tokenizer.ConsumeIdentifier()) + return '.'.join(prefix), '.'.join(name) + + def _MergeMessageField(self, tokenizer, message, field): + """Merges a single scalar field into a message. + + Args: + tokenizer: A tokenizer to parse the field value. + message: The message of which field is a member. + field: The descriptor of the field to be merged. + + Raises: + ParseError: In case of text parsing problems. + """ + is_map_entry = _IsMapEntry(field) + + if tokenizer.TryConsume('<'): + end_token = '>' + else: + tokenizer.Consume('{') + end_token = '}' + + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.is_extension: + sub_message = message.Extensions[field].add() + elif is_map_entry: + sub_message = getattr(message, field.name).GetEntryClass()() + else: + sub_message = getattr(message, field.name).add() + else: + if field.is_extension: + if (not self._allow_multiple_scalars and + message.HasExtension(field)): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" extensions.' % + (message.DESCRIPTOR.full_name, field.full_name)) + sub_message = message.Extensions[field] + else: + # Also apply _allow_multiple_scalars to message field. + # TODO(jieluo): Change to _allow_singular_overwrites. + if (not self._allow_multiple_scalars and + message.HasField(field.name)): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" fields.' % + (message.DESCRIPTOR.full_name, field.name)) + sub_message = getattr(message, field.name) + sub_message.SetInParent() + + while not tokenizer.TryConsume(end_token): + if tokenizer.AtEnd(): + raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token,)) + self._MergeField(tokenizer, sub_message) + + if is_map_entry: + value_cpptype = field.message_type.fields_by_name['value'].cpp_type + if value_cpptype == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + value = getattr(message, field.name)[sub_message.key] + value.CopyFrom(sub_message.value) + else: + getattr(message, field.name)[sub_message.key] = sub_message.value + + @staticmethod + def _IsProto3Syntax(message): + message_descriptor = message.DESCRIPTOR + return (hasattr(message_descriptor, 'syntax') and + message_descriptor.syntax == 'proto3') + + def _MergeScalarField(self, tokenizer, message, field): + """Merges a single scalar field into a message. + + Args: + tokenizer: A tokenizer to parse the field value. + message: A protocol message to record the data. + field: The descriptor of the field to be merged. + + Raises: + ParseError: In case of text parsing problems. + RuntimeError: On runtime errors. + """ + _ = self.allow_unknown_extension + value = None + + if field.type in (descriptor.FieldDescriptor.TYPE_INT32, + descriptor.FieldDescriptor.TYPE_SINT32, + descriptor.FieldDescriptor.TYPE_SFIXED32): + value = _ConsumeInt32(tokenizer) + elif field.type in (descriptor.FieldDescriptor.TYPE_INT64, + descriptor.FieldDescriptor.TYPE_SINT64, + descriptor.FieldDescriptor.TYPE_SFIXED64): + value = _ConsumeInt64(tokenizer) + elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32, + descriptor.FieldDescriptor.TYPE_FIXED32): + value = _ConsumeUint32(tokenizer) + elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64, + descriptor.FieldDescriptor.TYPE_FIXED64): + value = _ConsumeUint64(tokenizer) + elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT, + descriptor.FieldDescriptor.TYPE_DOUBLE): + value = tokenizer.ConsumeFloat() + elif field.type == descriptor.FieldDescriptor.TYPE_BOOL: + value = tokenizer.ConsumeBool() + elif field.type == descriptor.FieldDescriptor.TYPE_STRING: + value = tokenizer.ConsumeString() + elif field.type == descriptor.FieldDescriptor.TYPE_BYTES: + value = tokenizer.ConsumeByteString() + elif field.type == descriptor.FieldDescriptor.TYPE_ENUM: + value = tokenizer.ConsumeEnum(field) + else: + raise RuntimeError('Unknown field type %d' % field.type) + + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.is_extension: + message.Extensions[field].append(value) + else: + getattr(message, field.name).append(value) + else: + if field.is_extension: + if (not self._allow_multiple_scalars and + not self._IsProto3Syntax(message) and + message.HasExtension(field)): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" extensions.' % + (message.DESCRIPTOR.full_name, field.full_name)) + else: + message.Extensions[field] = value + else: + duplicate_error = False + if not self._allow_multiple_scalars: + if self._IsProto3Syntax(message): + # Proto3 doesn't represent presence so we try best effort to check + # multiple scalars by compare to default values. + duplicate_error = bool(getattr(message, field.name)) + else: + duplicate_error = message.HasField(field.name) + + if duplicate_error: + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" fields.' % + (message.DESCRIPTOR.full_name, field.name)) + else: + setattr(message, field.name, value) + + +def _SkipFieldContents(tokenizer): + """Skips over contents (value or message) of a field. + + Args: + tokenizer: A tokenizer to parse the field name and values. + """ + # Try to guess the type of this field. + # If this field is not a message, there should be a ":" between the + # field name and the field value and also the field value should not + # start with "{" or "<" which indicates the beginning of a message body. + # If there is no ":" or there is a "{" or "<" after ":", this field has + # to be a message or the input is ill-formed. + if tokenizer.TryConsume(':') and not tokenizer.LookingAt( + '{') and not tokenizer.LookingAt('<'): + _SkipFieldValue(tokenizer) + else: + _SkipFieldMessage(tokenizer) + + +def _SkipField(tokenizer): + """Skips over a complete field (name and value/message). + + Args: + tokenizer: A tokenizer to parse the field name and values. + """ + if tokenizer.TryConsume('['): + # Consume extension name. + tokenizer.ConsumeIdentifier() + while tokenizer.TryConsume('.'): + tokenizer.ConsumeIdentifier() + tokenizer.Consume(']') + else: + tokenizer.ConsumeIdentifierOrNumber() + + _SkipFieldContents(tokenizer) + + # For historical reasons, fields may optionally be separated by commas or + # semicolons. + if not tokenizer.TryConsume(','): + tokenizer.TryConsume(';') + + +def _SkipFieldMessage(tokenizer): + """Skips over a field message. + + Args: + tokenizer: A tokenizer to parse the field name and values. + """ + + if tokenizer.TryConsume('<'): + delimiter = '>' + else: + tokenizer.Consume('{') + delimiter = '}' + + while not tokenizer.LookingAt('>') and not tokenizer.LookingAt('}'): + _SkipField(tokenizer) + + tokenizer.Consume(delimiter) + + +def _SkipFieldValue(tokenizer): + """Skips over a field value. + + Args: + tokenizer: A tokenizer to parse the field name and values. + + Raises: + ParseError: In case an invalid field value is found. + """ + # String/bytes tokens can come in multiple adjacent string literals. + # If we can consume one, consume as many as we can. + if tokenizer.TryConsumeByteString(): + while tokenizer.TryConsumeByteString(): + pass + return + + if (not tokenizer.TryConsumeIdentifier() and + not _TryConsumeInt64(tokenizer) and not _TryConsumeUint64(tokenizer) and + not tokenizer.TryConsumeFloat()): + raise ParseError('Invalid field value: ' + tokenizer.token) + + +class Tokenizer(object): + """Protocol buffer text representation tokenizer. + + This class handles the lower level string parsing by splitting it into + meaningful tokens. + + It was directly ported from the Java protocol buffer API. + """ + + _WHITESPACE = re.compile(r'\s+') + _COMMENT = re.compile(r'(\s*#.*$)', re.MULTILINE) + _WHITESPACE_OR_COMMENT = re.compile(r'(\s|(#.*$))+', re.MULTILINE) + _TOKEN = re.compile('|'.join([ + r'[a-zA-Z_][0-9a-zA-Z_+-]*', # an identifier + r'([0-9+-]|(\.[0-9]))[0-9a-zA-Z_.+-]*', # a number + ] + [ # quoted str for each quote mark + # Avoid backtracking! https://stackoverflow.com/a/844267 + r'{qt}[^{qt}\n\\]*((\\.)+[^{qt}\n\\]*)*({qt}|\\?$)'.format(qt=mark) + for mark in _QUOTES + ])) + + _IDENTIFIER = re.compile(r'[^\d\W]\w*') + _IDENTIFIER_OR_NUMBER = re.compile(r'\w+') + + def __init__(self, lines, skip_comments=True): + self._position = 0 + self._line = -1 + self._column = 0 + self._token_start = None + self.token = '' + self._lines = iter(lines) + self._current_line = '' + self._previous_line = 0 + self._previous_column = 0 + self._more_lines = True + self._skip_comments = skip_comments + self._whitespace_pattern = (skip_comments and self._WHITESPACE_OR_COMMENT + or self._WHITESPACE) + self._SkipWhitespace() + self.NextToken() + + def LookingAt(self, token): + return self.token == token + + def AtEnd(self): + """Checks the end of the text was reached. + + Returns: + True iff the end was reached. + """ + return not self.token + + def _PopLine(self): + while len(self._current_line) <= self._column: + try: + self._current_line = next(self._lines) + except StopIteration: + self._current_line = '' + self._more_lines = False + return + else: + self._line += 1 + self._column = 0 + + def _SkipWhitespace(self): + while True: + self._PopLine() + match = self._whitespace_pattern.match(self._current_line, self._column) + if not match: + break + length = len(match.group(0)) + self._column += length + + def TryConsume(self, token): + """Tries to consume a given piece of text. + + Args: + token: Text to consume. + + Returns: + True iff the text was consumed. + """ + if self.token == token: + self.NextToken() + return True + return False + + def Consume(self, token): + """Consumes a piece of text. + + Args: + token: Text to consume. + + Raises: + ParseError: If the text couldn't be consumed. + """ + if not self.TryConsume(token): + raise self.ParseError('Expected "%s".' % token) + + def ConsumeComment(self): + result = self.token + if not self._COMMENT.match(result): + raise self.ParseError('Expected comment.') + self.NextToken() + return result + + def ConsumeCommentOrTrailingComment(self): + """Consumes a comment, returns a 2-tuple (trailing bool, comment str).""" + + # Tokenizer initializes _previous_line and _previous_column to 0. As the + # tokenizer starts, it looks like there is a previous token on the line. + just_started = self._line == 0 and self._column == 0 + + before_parsing = self._previous_line + comment = self.ConsumeComment() + + # A trailing comment is a comment on the same line than the previous token. + trailing = (self._previous_line == before_parsing + and not just_started) + + return trailing, comment + + def TryConsumeIdentifier(self): + try: + self.ConsumeIdentifier() + return True + except ParseError: + return False + + def ConsumeIdentifier(self): + """Consumes protocol message field identifier. + + Returns: + Identifier string. + + Raises: + ParseError: If an identifier couldn't be consumed. + """ + result = self.token + if not self._IDENTIFIER.match(result): + raise self.ParseError('Expected identifier.') + self.NextToken() + return result + + def TryConsumeIdentifierOrNumber(self): + try: + self.ConsumeIdentifierOrNumber() + return True + except ParseError: + return False + + def ConsumeIdentifierOrNumber(self): + """Consumes protocol message field identifier. + + Returns: + Identifier string. + + Raises: + ParseError: If an identifier couldn't be consumed. + """ + result = self.token + if not self._IDENTIFIER_OR_NUMBER.match(result): + raise self.ParseError('Expected identifier or number, got %s.' % result) + self.NextToken() + return result + + def TryConsumeInteger(self): + try: + # Note: is_long only affects value type, not whether an error is raised. + self.ConsumeInteger() + return True + except ParseError: + return False + + def ConsumeInteger(self, is_long=False): + """Consumes an integer number. + + Args: + is_long: True if the value should be returned as a long integer. + Returns: + The integer parsed. + + Raises: + ParseError: If an integer couldn't be consumed. + """ + try: + result = _ParseAbstractInteger(self.token, is_long=is_long) + except ValueError as e: + raise self.ParseError(str(e)) + self.NextToken() + return result + + def TryConsumeFloat(self): + try: + self.ConsumeFloat() + return True + except ParseError: + return False + + def ConsumeFloat(self): + """Consumes an floating point number. + + Returns: + The number parsed. + + Raises: + ParseError: If a floating point number couldn't be consumed. + """ + try: + result = ParseFloat(self.token) + except ValueError as e: + raise self.ParseError(str(e)) + self.NextToken() + return result + + def ConsumeBool(self): + """Consumes a boolean value. + + Returns: + The bool parsed. + + Raises: + ParseError: If a boolean value couldn't be consumed. + """ + try: + result = ParseBool(self.token) + except ValueError as e: + raise self.ParseError(str(e)) + self.NextToken() + return result + + def TryConsumeByteString(self): + try: + self.ConsumeByteString() + return True + except ParseError: + return False + + def ConsumeString(self): + """Consumes a string value. + + Returns: + The string parsed. + + Raises: + ParseError: If a string value couldn't be consumed. + """ + the_bytes = self.ConsumeByteString() + try: + return six.text_type(the_bytes, 'utf-8') + except UnicodeDecodeError as e: + raise self._StringParseError(e) + + def ConsumeByteString(self): + """Consumes a byte array value. + + Returns: + The array parsed (as a string). + + Raises: + ParseError: If a byte array value couldn't be consumed. + """ + the_list = [self._ConsumeSingleByteString()] + while self.token and self.token[0] in _QUOTES: + the_list.append(self._ConsumeSingleByteString()) + return b''.join(the_list) + + def _ConsumeSingleByteString(self): + """Consume one token of a string literal. + + String literals (whether bytes or text) can come in multiple adjacent + tokens which are automatically concatenated, like in C or Python. This + method only consumes one token. + + Returns: + The token parsed. + Raises: + ParseError: When the wrong format data is found. + """ + text = self.token + if len(text) < 1 or text[0] not in _QUOTES: + raise self.ParseError('Expected string but found: %r' % (text,)) + + if len(text) < 2 or text[-1] != text[0]: + raise self.ParseError('String missing ending quote: %r' % (text,)) + + try: + result = text_encoding.CUnescape(text[1:-1]) + except ValueError as e: + raise self.ParseError(str(e)) + self.NextToken() + return result + + def ConsumeEnum(self, field): + try: + result = ParseEnum(field, self.token) + except ValueError as e: + raise self.ParseError(str(e)) + self.NextToken() + return result + + def ParseErrorPreviousToken(self, message): + """Creates and *returns* a ParseError for the previously read token. + + Args: + message: A message to set for the exception. + + Returns: + A ParseError instance. + """ + return ParseError(message, self._previous_line + 1, + self._previous_column + 1) + + def ParseError(self, message): + """Creates and *returns* a ParseError for the current token.""" + return ParseError('\'' + self._current_line + '\': ' + message, + self._line + 1, self._column + 1) + + def _StringParseError(self, e): + return self.ParseError('Couldn\'t parse string: ' + str(e)) + + def NextToken(self): + """Reads the next meaningful token.""" + self._previous_line = self._line + self._previous_column = self._column + + self._column += len(self.token) + self._SkipWhitespace() + + if not self._more_lines: + self.token = '' + return + + match = self._TOKEN.match(self._current_line, self._column) + if not match and not self._skip_comments: + match = self._COMMENT.match(self._current_line, self._column) + if match: + token = match.group(0) + self.token = token + else: + self.token = self._current_line[self._column] + +# Aliased so it can still be accessed by current visibility violators. +# TODO(dbarnett): Migrate violators to textformat_tokenizer. +_Tokenizer = Tokenizer # pylint: disable=invalid-name + + +def _ConsumeInt32(tokenizer): + """Consumes a signed 32bit integer number from tokenizer. + + Args: + tokenizer: A tokenizer used to parse the number. + + Returns: + The integer parsed. + + Raises: + ParseError: If a signed 32bit integer couldn't be consumed. + """ + return _ConsumeInteger(tokenizer, is_signed=True, is_long=False) + + +def _ConsumeUint32(tokenizer): + """Consumes an unsigned 32bit integer number from tokenizer. + + Args: + tokenizer: A tokenizer used to parse the number. + + Returns: + The integer parsed. + + Raises: + ParseError: If an unsigned 32bit integer couldn't be consumed. + """ + return _ConsumeInteger(tokenizer, is_signed=False, is_long=False) + + +def _TryConsumeInt64(tokenizer): + try: + _ConsumeInt64(tokenizer) + return True + except ParseError: + return False + + +def _ConsumeInt64(tokenizer): + """Consumes a signed 32bit integer number from tokenizer. + + Args: + tokenizer: A tokenizer used to parse the number. + + Returns: + The integer parsed. + + Raises: + ParseError: If a signed 32bit integer couldn't be consumed. + """ + return _ConsumeInteger(tokenizer, is_signed=True, is_long=True) + + +def _TryConsumeUint64(tokenizer): + try: + _ConsumeUint64(tokenizer) + return True + except ParseError: + return False + + +def _ConsumeUint64(tokenizer): + """Consumes an unsigned 64bit integer number from tokenizer. + + Args: + tokenizer: A tokenizer used to parse the number. + + Returns: + The integer parsed. + + Raises: + ParseError: If an unsigned 64bit integer couldn't be consumed. + """ + return _ConsumeInteger(tokenizer, is_signed=False, is_long=True) + + +def _TryConsumeInteger(tokenizer, is_signed=False, is_long=False): + try: + _ConsumeInteger(tokenizer, is_signed=is_signed, is_long=is_long) + return True + except ParseError: + return False + + +def _ConsumeInteger(tokenizer, is_signed=False, is_long=False): + """Consumes an integer number from tokenizer. + + Args: + tokenizer: A tokenizer used to parse the number. + is_signed: True if a signed integer must be parsed. + is_long: True if a long integer must be parsed. + + Returns: + The integer parsed. + + Raises: + ParseError: If an integer with given characteristics couldn't be consumed. + """ + try: + result = ParseInteger(tokenizer.token, is_signed=is_signed, is_long=is_long) + except ValueError as e: + raise tokenizer.ParseError(str(e)) + tokenizer.NextToken() + return result + + +def ParseInteger(text, is_signed=False, is_long=False): + """Parses an integer. + + Args: + text: The text to parse. + is_signed: True if a signed integer must be parsed. + is_long: True if a long integer must be parsed. + + Returns: + The integer value. + + Raises: + ValueError: Thrown Iff the text is not a valid integer. + """ + # Do the actual parsing. Exception handling is propagated to caller. + result = _ParseAbstractInteger(text, is_long=is_long) + + # Check if the integer is sane. Exceptions handled by callers. + checker = _INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)] + checker.CheckValue(result) + return result + + +def _ParseAbstractInteger(text, is_long=False): + """Parses an integer without checking size/signedness. + + Args: + text: The text to parse. + is_long: True if the value should be returned as a long integer. + + Returns: + The integer value. + + Raises: + ValueError: Thrown Iff the text is not a valid integer. + """ + # Do the actual parsing. Exception handling is propagated to caller. + orig_text = text + c_octal_match = re.match(r'(-?)0(\d+)$', text) + if c_octal_match: + # Python 3 no longer supports 0755 octal syntax without the 'o', so + # we always use the '0o' prefix for multi-digit numbers starting with 0. + text = c_octal_match.group(1) + '0o' + c_octal_match.group(2) + try: + # We force 32-bit values to int and 64-bit values to long to make + # alternate implementations where the distinction is more significant + # (e.g. the C++ implementation) simpler. + if is_long: + return long(text, 0) + else: + return int(text, 0) + except ValueError: + raise ValueError('Couldn\'t parse integer: %s' % orig_text) + + +def ParseFloat(text): + """Parse a floating point number. + + Args: + text: Text to parse. + + Returns: + The number parsed. + + Raises: + ValueError: If a floating point number couldn't be parsed. + """ + try: + # Assume Python compatible syntax. + return float(text) + except ValueError: + # Check alternative spellings. + if _FLOAT_INFINITY.match(text): + if text[0] == '-': + return float('-inf') + else: + return float('inf') + elif _FLOAT_NAN.match(text): + return float('nan') + else: + # assume '1.0f' format + try: + return float(text.rstrip('f')) + except ValueError: + raise ValueError('Couldn\'t parse float: %s' % text) + + +def ParseBool(text): + """Parse a boolean value. + + Args: + text: Text to parse. + + Returns: + Boolean values parsed + + Raises: + ValueError: If text is not a valid boolean. + """ + if text in ('true', 't', '1', 'True'): + return True + elif text in ('false', 'f', '0', 'False'): + return False + else: + raise ValueError('Expected "true" or "false".') + + +def ParseEnum(field, value): + """Parse an enum value. + + The value can be specified by a number (the enum value), or by + a string literal (the enum name). + + Args: + field: Enum field descriptor. + value: String value. + + Returns: + Enum value number. + + Raises: + ValueError: If the enum value could not be parsed. + """ + enum_descriptor = field.enum_type + try: + number = int(value, 0) + except ValueError: + # Identifier. + enum_value = enum_descriptor.values_by_name.get(value, None) + if enum_value is None: + raise ValueError('Enum type "%s" has no value named %s.' % + (enum_descriptor.full_name, value)) + else: + # Numeric value. + if hasattr(field.file, 'syntax'): + # Attribute is checked for compatibility. + if field.file.syntax == 'proto3': + # Proto3 accept numeric unknown enums. + return number + enum_value = enum_descriptor.values_by_number.get(number, None) + if enum_value is None: + raise ValueError('Enum type "%s" has no value with number %d.' % + (enum_descriptor.full_name, number)) + return enum_value.number diff --git a/contrib/python/protobuf/py3/google/protobuf/util/__init__.py b/contrib/python/protobuf/py3/google/protobuf/util/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/protobuf/py3/google/protobuf/util/__init__.py diff --git a/contrib/python/protobuf/py3/ya.make b/contrib/python/protobuf/py3/ya.make new file mode 100644 index 0000000000..28099779e1 --- /dev/null +++ b/contrib/python/protobuf/py3/ya.make @@ -0,0 +1,95 @@ +# Generated by devtools/yamaker. + +PY3_LIBRARY() + +LICENSE(BSD-3-Clause) + +LICENSE_TEXTS(.yandex_meta/licenses.list.txt) + +OWNER( + g:cpp-committee + g:cpp-contrib +) + +VERSION(3.17.3) + +ORIGINAL_SOURCE(https://github.com/protocolbuffers/protobuf/archive/v3.17.3.tar.gz) + +PEERDIR( + contrib/libs/protobuf + contrib/libs/protobuf/builtin_proto/protos_from_protobuf + contrib/libs/protobuf/builtin_proto/protos_from_protoc + contrib/python/six +) + +ADDINCL( + contrib/python/protobuf/py3 +) + +NO_COMPILER_WARNINGS() + +NO_LINT() + +CFLAGS( + -DPYTHON_PROTO2_CPP_IMPL_V2 +) + +SRCS( + google/protobuf/internal/api_implementation.cc + google/protobuf/pyext/descriptor.cc + google/protobuf/pyext/descriptor_containers.cc + google/protobuf/pyext/descriptor_database.cc + google/protobuf/pyext/descriptor_pool.cc + google/protobuf/pyext/extension_dict.cc + google/protobuf/pyext/field.cc + google/protobuf/pyext/map_container.cc + google/protobuf/pyext/message.cc + google/protobuf/pyext/message_factory.cc + google/protobuf/pyext/message_module.cc + google/protobuf/pyext/repeated_composite_container.cc + google/protobuf/pyext/repeated_scalar_container.cc + google/protobuf/pyext/unknown_fields.cc +) + +PY_REGISTER( + google.protobuf.internal._api_implementation + google.protobuf.pyext._message +) + +PY_SRCS( + TOP_LEVEL + google/__init__.py + google/protobuf/__init__.py + google/protobuf/compiler/__init__.py + google/protobuf/descriptor.py + google/protobuf/descriptor_database.py + google/protobuf/descriptor_pool.py + google/protobuf/internal/__init__.py + google/protobuf/internal/_parameterized.py + google/protobuf/internal/api_implementation.py + google/protobuf/internal/containers.py + google/protobuf/internal/decoder.py + google/protobuf/internal/encoder.py + google/protobuf/internal/enum_type_wrapper.py + google/protobuf/internal/extension_dict.py + google/protobuf/internal/message_listener.py + google/protobuf/internal/python_message.py + google/protobuf/internal/type_checkers.py + google/protobuf/internal/well_known_types.py + google/protobuf/internal/wire_format.py + google/protobuf/json_format.py + google/protobuf/message.py + google/protobuf/message_factory.py + google/protobuf/proto_builder.py + google/protobuf/pyext/__init__.py + google/protobuf/pyext/cpp_message.py + google/protobuf/reflection.py + google/protobuf/service.py + google/protobuf/service_reflection.py + google/protobuf/symbol_database.py + google/protobuf/text_encoding.py + google/protobuf/text_format.py + google/protobuf/util/__init__.py +) + +END() |